aboutsummaryrefslogtreecommitdiffstats
path: root/nip04.go
blob: f979bcda2090e3a7ce1946a7901d96cbb9c157e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package nostr

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"strings"

	"code.northwest.io/nostr/internal/secp256k1"
)

// NIP04Encrypt produces a NIP-04 encrypted payload addressed to the holder of
// recipientPubHex (an x-only public key given as 64-char hex).
//
// The AES key is the 32-byte NIP-04 shared secret derived from k's private key
// and the recipient's public key; the message is PKCS#7 padded and encrypted
// with AES-CBC under a fresh random 16-byte IV.
//
// The result has the NIP-04 wire form "<base64 ciphertext>?iv=<base64 iv>".
// An error is returned when k holds no private key, when the shared secret
// cannot be derived, or when the random source or cipher setup fails.
func (k *Key) NIP04Encrypt(recipientPubHex, plaintext string) (string, error) {
	if k.priv == nil {
		return "", fmt.Errorf("cannot encrypt: public-only key")
	}

	secret, err := nip04SharedSecret(k.priv, recipientPubHex)
	if err != nil {
		return "", err
	}

	// A new IV is drawn for every message.
	var iv [aes.BlockSize]byte
	if _, err := rand.Read(iv[:]); err != nil {
		return "", fmt.Errorf("generating iv: %w", err)
	}

	block, err := aes.NewCipher(secret)
	if err != nil {
		return "", fmt.Errorf("aes cipher: %w", err)
	}

	// Pad to a whole number of blocks, then encrypt the buffer in place.
	buf := pkcs7Pad([]byte(plaintext), aes.BlockSize)
	cipher.NewCBCEncrypter(block, iv[:]).CryptBlocks(buf, buf)

	enc := base64.StdEncoding
	return enc.EncodeToString(buf) + "?iv=" + enc.EncodeToString(iv[:]), nil
}

// NIP04Decrypt decrypts a NIP-04 ciphertext using the receiver's private key
// and the sender's x-only public key as 64-char hex.
//
// ciphertext must have the form "<base64 ciphertext>?iv=<base64 iv>". The
// decoded IV must be exactly one AES block (16 bytes) and the decoded
// ciphertext must be a whole number of AES blocks; anything else is reported
// as an error rather than being handed to the CBC decrypter, which would
// panic on a wrong-sized IV.
//
// An error is also returned when k holds no private key, when the shared
// secret cannot be derived, or when the decrypted data does not end in valid
// padding.
func (k *Key) NIP04Decrypt(senderPubHex, ciphertext string) (string, error) {
	if k.priv == nil {
		return "", fmt.Errorf("cannot decrypt: public-only key")
	}

	parts := strings.SplitN(ciphertext, "?iv=", 2)
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid NIP-04 ciphertext format")
	}

	ct, err := base64.StdEncoding.DecodeString(parts[0])
	if err != nil {
		return "", fmt.Errorf("decode ciphertext: %w", err)
	}

	iv, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return "", fmt.Errorf("decode iv: %w", err)
	}

	shared, err := nip04SharedSecret(k.priv, senderPubHex)
	if err != nil {
		return "", err
	}

	block, err := aes.NewCipher(shared)
	if err != nil {
		return "", fmt.Errorf("aes cipher: %w", err)
	}

	if len(ct)%aes.BlockSize != 0 {
		return "", fmt.Errorf("ciphertext is not a multiple of block size")
	}

	// cipher.NewCBCDecrypter panics when the IV is not exactly one block long.
	// The IV comes from the (untrusted) message, so reject it here instead.
	if len(iv) != block.BlockSize() {
		return "", fmt.Errorf("invalid iv length: got %d, want %d", len(iv), block.BlockSize())
	}

	plain := make([]byte, len(ct))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(plain, ct)

	unpadded, err := pkcs7Unpad(plain)
	if err != nil {
		return "", fmt.Errorf("unpad: %w", err)
	}

	return string(unpadded), nil
}

// nip04SharedSecret derives the NIP-04 symmetric key shared between priv and
// the peer identified by recipientPubHex:
//
//	SHA256( (priv.D * peerPublicKey).x )
//
// recipientPubHex is hex-decoded and parsed as an x-only public key. The
// returned slice is the 32-byte SHA-256 digest. Errors are returned for
// undecodable hex, an unparsable key, or a multiplication result that yields
// no x-coordinate.
func nip04SharedSecret(priv *secp256k1.PrivateKey, recipientPubHex string) ([]byte, error) {
	raw, err := hex.DecodeString(recipientPubHex)
	if err != nil {
		return nil, fmt.Errorf("decode pubkey: %w", err)
	}

	// The peer key is supplied in x-only form.
	peer, err := secp256k1.ParsePublicKeyXOnly(raw)
	if err != nil {
		return nil, fmt.Errorf("parse pubkey: %w", err)
	}

	// ECDH step: multiply the peer's point by our private scalar.
	product := peer.Point.ScalarMul(priv.D)

	// A nil x-coordinate is reported as the point at infinity.
	x := product.XBytes()
	if x == nil {
		return nil, fmt.Errorf("ECDH result is point at infinity")
	}

	// The key is the hash of the x-coordinate only.
	digest := sha256.Sum256(x)
	return digest[:], nil
}

// pkcs7Pad returns a copy of data extended with PKCS#7 padding so that its
// length is a multiple of blockSize. Between 1 and blockSize bytes are
// appended, each holding the number of bytes appended; input that is already
// block-aligned gains one full block of padding. data itself is not modified.
func pkcs7Pad(data []byte, blockSize int) []byte {
	n := blockSize - len(data)%blockSize
	out := make([]byte, len(data), len(data)+n)
	copy(out, data)
	for i := 0; i < n; i++ {
		out = append(out, byte(n))
	}
	return out
}

// pkcs7Unpad strips PKCS#7 padding from data and returns the unpadded prefix.
// The result is a sub-slice of data, not a copy.
//
// An error is returned when:
//   - data is empty;
//   - the final byte is not a valid pad length for an AES block
//     (it must be between 1 and aes.BlockSize);
//   - data is shorter than the declared pad length;
//   - any of the trailing pad bytes differs from the pad length.
//
// Checking every pad byte, not just the last one, makes it much less likely
// that garbage produced by a wrong key or a corrupted ciphertext is accepted
// as a valid message.
func pkcs7Unpad(data []byte) ([]byte, error) {
	if len(data) == 0 {
		return nil, fmt.Errorf("empty data")
	}
	pad := int(data[len(data)-1])
	if pad == 0 || pad > aes.BlockSize {
		return nil, fmt.Errorf("invalid padding size: %d", pad)
	}
	if len(data) < pad {
		return nil, fmt.Errorf("data shorter than padding")
	}

	// PKCS#7 requires all pad bytes to carry the pad length.
	end := len(data) - pad
	for _, b := range data[end:] {
		if int(b) != pad {
			return nil, fmt.Errorf("invalid padding bytes")
		}
	}
	return data[:end], nil
}