aboutsummaryrefslogtreecommitdiffstats
path: root/nip04.go
blob: 1f1c245e04e66010ec2949cc24052dad84f32ad3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package nostr

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"strings"

	"github.com/btcsuite/btcd/btcec/v2"
)

// NIP04Encrypt seals plaintext for the owner of recipientPubHex following
// NIP-04: a key derived from ECDH between k's private key and the recipient's
// public key drives AES-256-CBC with a freshly generated random 16-byte IV.
//
// k must hold a private key; a public-only key yields an error. On success the
// result has the wire form "<base64 ciphertext>?iv=<base64 iv>".
func (k *Key) NIP04Encrypt(recipientPubHex, plaintext string) (string, error) {
	if k.priv == nil {
		return "", fmt.Errorf("cannot encrypt: public-only key")
	}

	key, err := nip04SharedSecret(k.priv, recipientPubHex)
	if err != nil {
		return "", err
	}

	var nonce [aes.BlockSize]byte
	if _, err = rand.Read(nonce[:]); err != nil {
		return "", fmt.Errorf("generating iv: %w", err)
	}

	blk, err := aes.NewCipher(key)
	if err != nil {
		return "", fmt.Errorf("aes cipher: %w", err)
	}

	// pkcs7Pad returns a fresh buffer whose length is a whole number of
	// blocks, so CBC can encrypt it in place.
	buf := pkcs7Pad([]byte(plaintext), aes.BlockSize)
	cipher.NewCBCEncrypter(blk, nonce[:]).CryptBlocks(buf, buf)

	enc := base64.StdEncoding
	return enc.EncodeToString(buf) + "?iv=" + enc.EncodeToString(nonce[:]), nil
}

// NIP04Decrypt opens a NIP-04 payload of the form
// "<base64 ciphertext>?iv=<base64 iv>" addressed to k. It uses k's private key
// and the sender's hex-encoded public key (normally the 64-char x-only form).
//
// It returns an error if:
//   - k is public-only;
//   - the payload has no "?iv=" separator;
//   - either part is not valid standard base64;
//   - the IV is not exactly one AES block;
//   - the ciphertext is not a whole number of blocks;
//   - the sender key cannot be parsed;
//   - the PKCS#7 padding does not validate after decryption.
//
// NIP-04 defines no MAC, so a successful return does not guarantee that the
// ciphertext was not tampered with in transit.
func (k *Key) NIP04Decrypt(senderPubHex, ciphertext string) (string, error) {
	if k.priv == nil {
		return "", fmt.Errorf("cannot decrypt: public-only key")
	}

	parts := strings.SplitN(ciphertext, "?iv=", 2)
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid NIP-04 ciphertext format")
	}

	ct, err := base64.StdEncoding.DecodeString(parts[0])
	if err != nil {
		return "", fmt.Errorf("decode ciphertext: %w", err)
	}

	iv, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return "", fmt.Errorf("decode iv: %w", err)
	}
	// The IV comes from untrusted input, and cipher.NewCBCDecrypter panics
	// when its length differs from the block size, so reject it here.
	if len(iv) != aes.BlockSize {
		return "", fmt.Errorf("invalid iv length: %d", len(iv))
	}

	shared, err := nip04SharedSecret(k.priv, senderPubHex)
	if err != nil {
		return "", err
	}

	block, err := aes.NewCipher(shared)
	if err != nil {
		return "", fmt.Errorf("aes cipher: %w", err)
	}

	// CryptBlocks panics on partial blocks; check before decrypting.
	if len(ct)%aes.BlockSize != 0 {
		return "", fmt.Errorf("ciphertext is not a multiple of block size")
	}

	plain := make([]byte, len(ct))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(plain, ct)

	unpadded, err := pkcs7Unpad(plain)
	if err != nil {
		return "", fmt.Errorf("unpad: %w", err)
	}

	return string(unpadded), nil
}

// nip04SharedSecret derives the 32-byte AES-256 key used by NIP-04 from an
// ECDH exchange between priv and the public key encoded in recipientPubHex.
//
// Per NIP-04 the key is the raw x-coordinate of the shared point priv·P, and
// the spec states explicitly that it is NOT hashed. libsecp256k1's default
// ECDH hashes the point, which is why the spec calls this out. The unhashed
// value is required to interoperate with other NIP-04 clients.
//
// recipientPubHex may be a 32-byte x-only key (the usual Nostr form) or any
// encoding btcec.ParsePubKey accepts (e.g. 33-byte compressed SEC1).
// priv is only read, never modified, so concurrent calls with the same key
// are safe with respect to this function.
func nip04SharedSecret(priv *btcec.PrivateKey, recipientPubHex string) ([]byte, error) {
	pubBytes, err := hex.DecodeString(recipientPubHex)
	if err != nil {
		return nil, fmt.Errorf("decode pubkey: %w", err)
	}

	// x-only keys carry no y parity. Prefix 0x02 (even y) so SEC1 parsing
	// succeeds. The choice cannot affect the result: k·(−P) = −(k·P), which
	// has the same x-coordinate.
	if len(pubBytes) == 32 {
		pubBytes = append([]byte{0x02}, pubBytes...)
	}

	pub, err := btcec.ParsePubKey(pubBytes)
	if err != nil {
		return nil, fmt.Errorf("parse pubkey: %w", err)
	}

	// GenerateSharedSecret returns the x-coordinate of priv·pub, unhashed.
	return btcec.GenerateSharedSecret(priv, pub), nil
}

// nip04LegacySharedSecret reproduces the key derivation used by earlier
// versions of this package, which deviated from NIP-04 by hashing the shared
// x-coordinate with SHA-256. Ciphertexts written by those versions can only
// be decrypted with this key; it exists solely to migrate such data and must
// never be used to encrypt new messages.
func nip04LegacySharedSecret(priv *btcec.PrivateKey, recipientPubHex string) ([]byte, error) {
	x, err := nip04SharedSecret(priv, recipientPubHex)
	if err != nil {
		return nil, err
	}
	sum := sha256.Sum256(x)
	return sum[:], nil
}

// pkcs7Pad returns a new slice holding data followed by PKCS#7 padding.
// Between 1 and blockSize bytes are appended, each equal to the number of
// bytes added, so the result length is a multiple of blockSize. An
// already-aligned input gains a full block of padding. data is not modified.
// blockSize is assumed to be in 1..255; other values are not checked.
func pkcs7Pad(data []byte, blockSize int) []byte {
	padLen := blockSize - len(data)%blockSize
	out := make([]byte, len(data)+padLen)
	// copy reports how many bytes it wrote, which is exactly where padding starts.
	tail := out[copy(out, data):]
	for i := range tail {
		tail[i] = byte(padLen)
	}
	return out
}

// pkcs7Unpad strips PKCS#7 padding (block size aes.BlockSize) from data and
// returns the remaining prefix. The result aliases data rather than copying it.
//
// Every padding byte is checked, not just the last one. Checking only the
// last byte would accept malformed tails and would let decryption under a
// wrong key "succeed" far more often.
func pkcs7Unpad(data []byte) ([]byte, error) {
	if len(data) == 0 {
		return nil, fmt.Errorf("empty data")
	}
	pad := int(data[len(data)-1])
	if pad == 0 || pad > aes.BlockSize {
		return nil, fmt.Errorf("invalid padding size: %d", pad)
	}
	if len(data) < pad {
		return nil, fmt.Errorf("data shorter than padding")
	}
	for _, b := range data[len(data)-pad:] {
		if int(b) != pad {
			return nil, fmt.Errorf("invalid padding bytes")
		}
	}
	return data[:len(data)-pad], nil
}