diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/nostr/bech32.go | 162 | ||||
| -rw-r--r-- | internal/nostr/bech32_test.go | 139 | ||||
| -rw-r--r-- | internal/nostr/envelope.go | 262 | ||||
| -rw-r--r-- | internal/nostr/envelope_test.go | 416 | ||||
| -rw-r--r-- | internal/nostr/event.go | 72 | ||||
| -rw-r--r-- | internal/nostr/event_test.go | 194 | ||||
| -rw-r--r-- | internal/nostr/example_test.go | 85 | ||||
| -rw-r--r-- | internal/nostr/filter.go | 224 | ||||
| -rw-r--r-- | internal/nostr/filter_test.go | 415 | ||||
| -rw-r--r-- | internal/nostr/keys.go | 217 | ||||
| -rw-r--r-- | internal/nostr/keys_test.go | 333 | ||||
| -rw-r--r-- | internal/nostr/kinds.go | 51 | ||||
| -rw-r--r-- | internal/nostr/kinds_test.go | 128 | ||||
| -rw-r--r-- | internal/nostr/relay.go | 305 | ||||
| -rw-r--r-- | internal/nostr/relay_test.go | 326 | ||||
| -rw-r--r-- | internal/nostr/tags.go | 64 | ||||
| -rw-r--r-- | internal/nostr/tags_test.go | 158 | ||||
| -rw-r--r-- | internal/websocket/websocket.go | 297 |
18 files changed, 3848 insertions, 0 deletions
diff --git a/internal/nostr/bech32.go b/internal/nostr/bech32.go new file mode 100644 index 0000000..c8b1293 --- /dev/null +++ b/internal/nostr/bech32.go | |||
| @@ -0,0 +1,162 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "fmt" | ||
| 5 | "strings" | ||
| 6 | ) | ||
| 7 | |||
| 8 | // Bech32 encoding/decoding for NIP-19 (npub, nsec, note, etc.) | ||
| 9 | // Implements BIP-173 bech32 encoding. | ||
| 10 | |||
| 11 | const bech32Alphabet = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" | ||
| 12 | |||
| 13 | var bech32AlphabetMap [256]int8 | ||
| 14 | |||
| 15 | func init() { | ||
| 16 | for i := range bech32AlphabetMap { | ||
| 17 | bech32AlphabetMap[i] = -1 | ||
| 18 | } | ||
| 19 | for i, c := range bech32Alphabet { | ||
| 20 | bech32AlphabetMap[c] = int8(i) | ||
| 21 | } | ||
| 22 | } | ||
| 23 | |||
| 24 | // bech32Polymod computes the BCH checksum. | ||
| 25 | func bech32Polymod(values []int) int { | ||
| 26 | gen := []int{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3} | ||
| 27 | chk := 1 | ||
| 28 | for _, v := range values { | ||
| 29 | top := chk >> 25 | ||
| 30 | chk = (chk&0x1ffffff)<<5 ^ v | ||
| 31 | for i := 0; i < 5; i++ { | ||
| 32 | if (top>>i)&1 == 1 { | ||
| 33 | chk ^= gen[i] | ||
| 34 | } | ||
| 35 | } | ||
| 36 | } | ||
| 37 | return chk | ||
| 38 | } | ||
| 39 | |||
| 40 | // bech32HRPExpand expands the human-readable part for checksum computation. | ||
| 41 | func bech32HRPExpand(hrp string) []int { | ||
| 42 | result := make([]int, len(hrp)*2+1) | ||
| 43 | for i, c := range hrp { | ||
| 44 | result[i] = int(c) >> 5 | ||
| 45 | result[i+len(hrp)+1] = int(c) & 31 | ||
| 46 | } | ||
| 47 | return result | ||
| 48 | } | ||
| 49 | |||
| 50 | // bech32CreateChecksum creates the 6-character checksum. | ||
| 51 | func bech32CreateChecksum(hrp string, data []int) []int { | ||
| 52 | values := append(bech32HRPExpand(hrp), data...) | ||
| 53 | values = append(values, []int{0, 0, 0, 0, 0, 0}...) | ||
| 54 | polymod := bech32Polymod(values) ^ 1 | ||
| 55 | checksum := make([]int, 6) | ||
| 56 | for i := 0; i < 6; i++ { | ||
| 57 | checksum[i] = (polymod >> (5 * (5 - i))) & 31 | ||
| 58 | } | ||
| 59 | return checksum | ||
| 60 | } | ||
| 61 | |||
| 62 | // bech32VerifyChecksum verifies the checksum of bech32 data. | ||
| 63 | func bech32VerifyChecksum(hrp string, data []int) bool { | ||
| 64 | return bech32Polymod(append(bech32HRPExpand(hrp), data...)) == 1 | ||
| 65 | } | ||
| 66 | |||
| 67 | // convertBits converts between bit groups. | ||
| 68 | func convertBits(data []byte, fromBits, toBits int, pad bool) ([]int, error) { | ||
| 69 | acc := 0 | ||
| 70 | bits := 0 | ||
| 71 | result := make([]int, 0, len(data)*fromBits/toBits+1) | ||
| 72 | maxv := (1 << toBits) - 1 | ||
| 73 | |||
| 74 | for _, value := range data { | ||
| 75 | acc = (acc << fromBits) | int(value) | ||
| 76 | bits += fromBits | ||
| 77 | for bits >= toBits { | ||
| 78 | bits -= toBits | ||
| 79 | result = append(result, (acc>>bits)&maxv) | ||
| 80 | } | ||
| 81 | } | ||
| 82 | |||
| 83 | if pad { | ||
| 84 | if bits > 0 { | ||
| 85 | result = append(result, (acc<<(toBits-bits))&maxv) | ||
| 86 | } | ||
| 87 | } else if bits >= fromBits || ((acc<<(toBits-bits))&maxv) != 0 { | ||
| 88 | return nil, fmt.Errorf("invalid padding") | ||
| 89 | } | ||
| 90 | |||
| 91 | return result, nil | ||
| 92 | } | ||
| 93 | |||
| 94 | // Bech32Encode encodes data with the given human-readable prefix. | ||
| 95 | func Bech32Encode(hrp string, data []byte) (string, error) { | ||
| 96 | values, err := convertBits(data, 8, 5, true) | ||
| 97 | if err != nil { | ||
| 98 | return "", err | ||
| 99 | } | ||
| 100 | |||
| 101 | checksum := bech32CreateChecksum(hrp, values) | ||
| 102 | combined := append(values, checksum...) | ||
| 103 | |||
| 104 | var result strings.Builder | ||
| 105 | result.WriteString(hrp) | ||
| 106 | result.WriteByte('1') | ||
| 107 | for _, v := range combined { | ||
| 108 | result.WriteByte(bech32Alphabet[v]) | ||
| 109 | } | ||
| 110 | |||
| 111 | return result.String(), nil | ||
| 112 | } | ||
| 113 | |||
| 114 | // Bech32Decode decodes a bech32 string, returning the HRP and data. | ||
| 115 | func Bech32Decode(s string) (string, []byte, error) { | ||
| 116 | s = strings.ToLower(s) | ||
| 117 | |||
| 118 | pos := strings.LastIndexByte(s, '1') | ||
| 119 | if pos < 1 || pos+7 > len(s) { | ||
| 120 | return "", nil, fmt.Errorf("invalid bech32 string") | ||
| 121 | } | ||
| 122 | |||
| 123 | hrp := s[:pos] | ||
| 124 | dataStr := s[pos+1:] | ||
| 125 | |||
| 126 | data := make([]int, len(dataStr)) | ||
| 127 | for i, c := range dataStr { | ||
| 128 | val := bech32AlphabetMap[c] | ||
| 129 | if val == -1 { | ||
| 130 | return "", nil, fmt.Errorf("invalid character: %c", c) | ||
| 131 | } | ||
| 132 | data[i] = int(val) | ||
| 133 | } | ||
| 134 | |||
| 135 | if !bech32VerifyChecksum(hrp, data) { | ||
| 136 | return "", nil, fmt.Errorf("invalid checksum") | ||
| 137 | } | ||
| 138 | |||
| 139 | // Remove checksum | ||
| 140 | data = data[:len(data)-6] | ||
| 141 | |||
| 142 | // Convert from 5-bit to 8-bit | ||
| 143 | result, err := convertBits(intSliceToBytes(data), 5, 8, false) | ||
| 144 | if err != nil { | ||
| 145 | return "", nil, err | ||
| 146 | } | ||
| 147 | |||
| 148 | bytes := make([]byte, len(result)) | ||
| 149 | for i, v := range result { | ||
| 150 | bytes[i] = byte(v) | ||
| 151 | } | ||
| 152 | |||
| 153 | return hrp, bytes, nil | ||
| 154 | } | ||
| 155 | |||
| 156 | func intSliceToBytes(data []int) []byte { | ||
| 157 | result := make([]byte, len(data)) | ||
| 158 | for i, v := range data { | ||
| 159 | result[i] = byte(v) | ||
| 160 | } | ||
| 161 | return result | ||
| 162 | } | ||
diff --git a/internal/nostr/bech32_test.go b/internal/nostr/bech32_test.go new file mode 100644 index 0000000..fb1260b --- /dev/null +++ b/internal/nostr/bech32_test.go | |||
| @@ -0,0 +1,139 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "bytes" | ||
| 5 | "encoding/hex" | ||
| 6 | "testing" | ||
| 7 | ) | ||
| 8 | |||
| 9 | func TestBech32Encode(t *testing.T) { | ||
| 10 | // Test vector: 32 bytes of data | ||
| 11 | data, _ := hex.DecodeString("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef") | ||
| 12 | |||
| 13 | encoded, err := Bech32Encode("npub", data) | ||
| 14 | if err != nil { | ||
| 15 | t.Fatalf("Bech32Encode() error = %v", err) | ||
| 16 | } | ||
| 17 | |||
| 18 | if encoded[:5] != "npub1" { | ||
| 19 | t.Errorf("Encoded string should start with 'npub1', got %s", encoded[:5]) | ||
| 20 | } | ||
| 21 | |||
| 22 | // Decode it back | ||
| 23 | hrp, decoded, err := Bech32Decode(encoded) | ||
| 24 | if err != nil { | ||
| 25 | t.Fatalf("Bech32Decode() error = %v", err) | ||
| 26 | } | ||
| 27 | |||
| 28 | if hrp != "npub" { | ||
| 29 | t.Errorf("HRP = %s, want npub", hrp) | ||
| 30 | } | ||
| 31 | |||
| 32 | if !bytes.Equal(decoded, data) { | ||
| 33 | t.Errorf("Round-trip failed: got %x, want %x", decoded, data) | ||
| 34 | } | ||
| 35 | } | ||
| 36 | |||
| 37 | func TestBech32EncodeNsec(t *testing.T) { | ||
| 38 | data, _ := hex.DecodeString("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") | ||
| 39 | |||
| 40 | encoded, err := Bech32Encode("nsec", data) | ||
| 41 | if err != nil { | ||
| 42 | t.Fatalf("Bech32Encode() error = %v", err) | ||
| 43 | } | ||
| 44 | |||
| 45 | if encoded[:5] != "nsec1" { | ||
| 46 | t.Errorf("Encoded string should start with 'nsec1', got %s", encoded[:5]) | ||
| 47 | } | ||
| 48 | |||
| 49 | // Decode it back | ||
| 50 | hrp, decoded, err := Bech32Decode(encoded) | ||
| 51 | if err != nil { | ||
| 52 | t.Fatalf("Bech32Decode() error = %v", err) | ||
| 53 | } | ||
| 54 | |||
| 55 | if hrp != "nsec" { | ||
| 56 | t.Errorf("HRP = %s, want nsec", hrp) | ||
| 57 | } | ||
| 58 | |||
| 59 | if !bytes.Equal(decoded, data) { | ||
| 60 | t.Errorf("Round-trip failed") | ||
| 61 | } | ||
| 62 | } | ||
| 63 | |||
| 64 | func TestBech32DecodeErrors(t *testing.T) { | ||
| 65 | tests := []struct { | ||
| 66 | name string | ||
| 67 | input string | ||
| 68 | }{ | ||
| 69 | {"no separator", "npubabcdef"}, | ||
| 70 | {"empty data", "npub1"}, | ||
| 71 | {"invalid character", "npub1o"}, // 'o' is not in bech32 alphabet | ||
| 72 | {"invalid checksum", "npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqpqqqqq"}, | ||
| 73 | } | ||
| 74 | |||
| 75 | for _, tt := range tests { | ||
| 76 | t.Run(tt.name, func(t *testing.T) { | ||
| 77 | _, _, err := Bech32Decode(tt.input) | ||
| 78 | if err == nil { | ||
| 79 | t.Error("Bech32Decode() expected error, got nil") | ||
| 80 | } | ||
| 81 | }) | ||
| 82 | } | ||
| 83 | } | ||
| 84 | |||
| 85 | func TestBech32KnownVectors(t *testing.T) { | ||
| 86 | // Test with known nostr npub/nsec values | ||
| 87 | // These can be verified with other nostr tools | ||
| 88 | |||
| 89 | // Generate a key and verify round-trip | ||
| 90 | key, err := GenerateKey() | ||
| 91 | if err != nil { | ||
| 92 | t.Fatalf("GenerateKey() error = %v", err) | ||
| 93 | } | ||
| 94 | |||
| 95 | npub := key.Npub() | ||
| 96 | nsec := key.Nsec() | ||
| 97 | |||
| 98 | // Verify npub decodes to public key | ||
| 99 | hrp, pubBytes, err := Bech32Decode(npub) | ||
| 100 | if err != nil { | ||
| 101 | t.Fatalf("Bech32Decode(npub) error = %v", err) | ||
| 102 | } | ||
| 103 | if hrp != "npub" { | ||
| 104 | t.Errorf("npub HRP = %s, want npub", hrp) | ||
| 105 | } | ||
| 106 | if hex.EncodeToString(pubBytes) != key.Public() { | ||
| 107 | t.Error("npub does not decode to correct public key") | ||
| 108 | } | ||
| 109 | |||
| 110 | // Verify nsec decodes to private key | ||
| 111 | hrp, privBytes, err := Bech32Decode(nsec) | ||
| 112 | if err != nil { | ||
| 113 | t.Fatalf("Bech32Decode(nsec) error = %v", err) | ||
| 114 | } | ||
| 115 | if hrp != "nsec" { | ||
| 116 | t.Errorf("nsec HRP = %s, want nsec", hrp) | ||
| 117 | } | ||
| 118 | if hex.EncodeToString(privBytes) != key.Private() { | ||
| 119 | t.Error("nsec does not decode to correct private key") | ||
| 120 | } | ||
| 121 | } | ||
| 122 | |||
| 123 | func TestBech32CaseInsensitive(t *testing.T) { | ||
| 124 | data, _ := hex.DecodeString("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef") | ||
| 125 | encoded, _ := Bech32Encode("npub", data) | ||
| 126 | |||
| 127 | // Test uppercase | ||
| 128 | upper := "NPUB1" + encoded[5:] | ||
| 129 | hrp, decoded, err := Bech32Decode(upper) | ||
| 130 | if err != nil { | ||
| 131 | t.Fatalf("Bech32Decode(uppercase) error = %v", err) | ||
| 132 | } | ||
| 133 | if hrp != "npub" { | ||
| 134 | t.Errorf("HRP = %s, want npub", hrp) | ||
| 135 | } | ||
| 136 | if !bytes.Equal(decoded, data) { | ||
| 137 | t.Error("Uppercase decode failed") | ||
| 138 | } | ||
| 139 | } | ||
diff --git a/internal/nostr/envelope.go b/internal/nostr/envelope.go new file mode 100644 index 0000000..d395efa --- /dev/null +++ b/internal/nostr/envelope.go | |||
| @@ -0,0 +1,262 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "encoding/json" | ||
| 5 | "fmt" | ||
| 6 | ) | ||
| 7 | |||
| 8 | // Envelope represents a Nostr protocol message. | ||
| 9 | type Envelope interface { | ||
| 10 | Label() string | ||
| 11 | MarshalJSON() ([]byte, error) | ||
| 12 | } | ||
| 13 | |||
| 14 | // EventEnvelope wraps an event for the EVENT message. | ||
| 15 | // Used both client→relay (publish) and relay→client (subscription). | ||
| 16 | type EventEnvelope struct { | ||
| 17 | SubscriptionID string // Only for relay→client messages | ||
| 18 | Event *Event | ||
| 19 | } | ||
| 20 | |||
| 21 | func (e EventEnvelope) Label() string { return "EVENT" } | ||
| 22 | |||
| 23 | func (e EventEnvelope) MarshalJSON() ([]byte, error) { | ||
| 24 | if e.SubscriptionID != "" { | ||
| 25 | return json.Marshal([]interface{}{"EVENT", e.SubscriptionID, e.Event}) | ||
| 26 | } | ||
| 27 | return json.Marshal([]interface{}{"EVENT", e.Event}) | ||
| 28 | } | ||
| 29 | |||
| 30 | // ReqEnvelope represents a REQ message (client→relay). | ||
| 31 | type ReqEnvelope struct { | ||
| 32 | SubscriptionID string | ||
| 33 | Filters []Filter | ||
| 34 | } | ||
| 35 | |||
| 36 | func (e ReqEnvelope) Label() string { return "REQ" } | ||
| 37 | |||
| 38 | func (e ReqEnvelope) MarshalJSON() ([]byte, error) { | ||
| 39 | arr := make([]interface{}, 2+len(e.Filters)) | ||
| 40 | arr[0] = "REQ" | ||
| 41 | arr[1] = e.SubscriptionID | ||
| 42 | for i, f := range e.Filters { | ||
| 43 | arr[2+i] = f | ||
| 44 | } | ||
| 45 | return json.Marshal(arr) | ||
| 46 | } | ||
| 47 | |||
| 48 | // CloseEnvelope represents a CLOSE message (client→relay). | ||
| 49 | type CloseEnvelope struct { | ||
| 50 | SubscriptionID string | ||
| 51 | } | ||
| 52 | |||
| 53 | func (e CloseEnvelope) Label() string { return "CLOSE" } | ||
| 54 | |||
| 55 | func (e CloseEnvelope) MarshalJSON() ([]byte, error) { | ||
| 56 | return json.Marshal([]interface{}{"CLOSE", e.SubscriptionID}) | ||
| 57 | } | ||
| 58 | |||
| 59 | // OKEnvelope represents an OK message (relay→client). | ||
| 60 | type OKEnvelope struct { | ||
| 61 | EventID string | ||
| 62 | OK bool | ||
| 63 | Message string | ||
| 64 | } | ||
| 65 | |||
| 66 | func (e OKEnvelope) Label() string { return "OK" } | ||
| 67 | |||
| 68 | func (e OKEnvelope) MarshalJSON() ([]byte, error) { | ||
| 69 | return json.Marshal([]interface{}{"OK", e.EventID, e.OK, e.Message}) | ||
| 70 | } | ||
| 71 | |||
| 72 | // EOSEEnvelope represents an EOSE (End of Stored Events) message (relay→client). | ||
| 73 | type EOSEEnvelope struct { | ||
| 74 | SubscriptionID string | ||
| 75 | } | ||
| 76 | |||
| 77 | func (e EOSEEnvelope) Label() string { return "EOSE" } | ||
| 78 | |||
| 79 | func (e EOSEEnvelope) MarshalJSON() ([]byte, error) { | ||
| 80 | return json.Marshal([]interface{}{"EOSE", e.SubscriptionID}) | ||
| 81 | } | ||
| 82 | |||
| 83 | // ClosedEnvelope represents a CLOSED message (relay→client). | ||
| 84 | type ClosedEnvelope struct { | ||
| 85 | SubscriptionID string | ||
| 86 | Message string | ||
| 87 | } | ||
| 88 | |||
| 89 | func (e ClosedEnvelope) Label() string { return "CLOSED" } | ||
| 90 | |||
| 91 | func (e ClosedEnvelope) MarshalJSON() ([]byte, error) { | ||
| 92 | return json.Marshal([]interface{}{"CLOSED", e.SubscriptionID, e.Message}) | ||
| 93 | } | ||
| 94 | |||
| 95 | // NoticeEnvelope represents a NOTICE message (relay→client). | ||
| 96 | type NoticeEnvelope struct { | ||
| 97 | Message string | ||
| 98 | } | ||
| 99 | |||
| 100 | func (e NoticeEnvelope) Label() string { return "NOTICE" } | ||
| 101 | |||
| 102 | func (e NoticeEnvelope) MarshalJSON() ([]byte, error) { | ||
| 103 | return json.Marshal([]interface{}{"NOTICE", e.Message}) | ||
| 104 | } | ||
| 105 | |||
| 106 | // ParseEnvelope parses a raw JSON message into the appropriate envelope type. | ||
| 107 | func ParseEnvelope(data []byte) (Envelope, error) { | ||
| 108 | var arr []json.RawMessage | ||
| 109 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 110 | return nil, fmt.Errorf("invalid envelope: %w", err) | ||
| 111 | } | ||
| 112 | |||
| 113 | if len(arr) < 2 { | ||
| 114 | return nil, fmt.Errorf("envelope too short") | ||
| 115 | } | ||
| 116 | |||
| 117 | var label string | ||
| 118 | if err := json.Unmarshal(arr[0], &label); err != nil { | ||
| 119 | return nil, fmt.Errorf("invalid envelope label: %w", err) | ||
| 120 | } | ||
| 121 | |||
| 122 | switch label { | ||
| 123 | case "EVENT": | ||
| 124 | return parseEventEnvelope(arr) | ||
| 125 | case "REQ": | ||
| 126 | return parseReqEnvelope(arr) | ||
| 127 | case "CLOSE": | ||
| 128 | return parseCloseEnvelope(arr) | ||
| 129 | case "OK": | ||
| 130 | return parseOKEnvelope(arr) | ||
| 131 | case "EOSE": | ||
| 132 | return parseEOSEEnvelope(arr) | ||
| 133 | case "CLOSED": | ||
| 134 | return parseClosedEnvelope(arr) | ||
| 135 | case "NOTICE": | ||
| 136 | return parseNoticeEnvelope(arr) | ||
| 137 | default: | ||
| 138 | return nil, fmt.Errorf("unknown envelope type: %s", label) | ||
| 139 | } | ||
| 140 | } | ||
| 141 | |||
| 142 | func parseEventEnvelope(arr []json.RawMessage) (*EventEnvelope, error) { | ||
| 143 | env := &EventEnvelope{} | ||
| 144 | |||
| 145 | if len(arr) == 2 { | ||
| 146 | // Client→relay: ["EVENT", event] | ||
| 147 | var event Event | ||
| 148 | if err := json.Unmarshal(arr[1], &event); err != nil { | ||
| 149 | return nil, fmt.Errorf("invalid event: %w", err) | ||
| 150 | } | ||
| 151 | env.Event = &event | ||
| 152 | } else if len(arr) == 3 { | ||
| 153 | // Relay→client: ["EVENT", subscription_id, event] | ||
| 154 | if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { | ||
| 155 | return nil, fmt.Errorf("invalid subscription ID: %w", err) | ||
| 156 | } | ||
| 157 | var event Event | ||
| 158 | if err := json.Unmarshal(arr[2], &event); err != nil { | ||
| 159 | return nil, fmt.Errorf("invalid event: %w", err) | ||
| 160 | } | ||
| 161 | env.Event = &event | ||
| 162 | } else { | ||
| 163 | return nil, fmt.Errorf("invalid EVENT envelope length: %d", len(arr)) | ||
| 164 | } | ||
| 165 | |||
| 166 | return env, nil | ||
| 167 | } | ||
| 168 | |||
| 169 | func parseReqEnvelope(arr []json.RawMessage) (*ReqEnvelope, error) { | ||
| 170 | if len(arr) < 3 { | ||
| 171 | return nil, fmt.Errorf("REQ envelope must have at least 3 elements") | ||
| 172 | } | ||
| 173 | |||
| 174 | env := &ReqEnvelope{} | ||
| 175 | if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { | ||
| 176 | return nil, fmt.Errorf("invalid subscription ID: %w", err) | ||
| 177 | } | ||
| 178 | |||
| 179 | for i := 2; i < len(arr); i++ { | ||
| 180 | var filter Filter | ||
| 181 | if err := json.Unmarshal(arr[i], &filter); err != nil { | ||
| 182 | return nil, fmt.Errorf("invalid filter at index %d: %w", i-2, err) | ||
| 183 | } | ||
| 184 | env.Filters = append(env.Filters, filter) | ||
| 185 | } | ||
| 186 | |||
| 187 | return env, nil | ||
| 188 | } | ||
| 189 | |||
| 190 | func parseCloseEnvelope(arr []json.RawMessage) (*CloseEnvelope, error) { | ||
| 191 | if len(arr) != 2 { | ||
| 192 | return nil, fmt.Errorf("CLOSE envelope must have exactly 2 elements") | ||
| 193 | } | ||
| 194 | |||
| 195 | env := &CloseEnvelope{} | ||
| 196 | if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { | ||
| 197 | return nil, fmt.Errorf("invalid subscription ID: %w", err) | ||
| 198 | } | ||
| 199 | |||
| 200 | return env, nil | ||
| 201 | } | ||
| 202 | |||
| 203 | func parseOKEnvelope(arr []json.RawMessage) (*OKEnvelope, error) { | ||
| 204 | if len(arr) != 4 { | ||
| 205 | return nil, fmt.Errorf("OK envelope must have exactly 4 elements") | ||
| 206 | } | ||
| 207 | |||
| 208 | env := &OKEnvelope{} | ||
| 209 | if err := json.Unmarshal(arr[1], &env.EventID); err != nil { | ||
| 210 | return nil, fmt.Errorf("invalid event ID: %w", err) | ||
| 211 | } | ||
| 212 | if err := json.Unmarshal(arr[2], &env.OK); err != nil { | ||
| 213 | return nil, fmt.Errorf("invalid OK status: %w", err) | ||
| 214 | } | ||
| 215 | if err := json.Unmarshal(arr[3], &env.Message); err != nil { | ||
| 216 | return nil, fmt.Errorf("invalid message: %w", err) | ||
| 217 | } | ||
| 218 | |||
| 219 | return env, nil | ||
| 220 | } | ||
| 221 | |||
| 222 | func parseEOSEEnvelope(arr []json.RawMessage) (*EOSEEnvelope, error) { | ||
| 223 | if len(arr) != 2 { | ||
| 224 | return nil, fmt.Errorf("EOSE envelope must have exactly 2 elements") | ||
| 225 | } | ||
| 226 | |||
| 227 | env := &EOSEEnvelope{} | ||
| 228 | if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { | ||
| 229 | return nil, fmt.Errorf("invalid subscription ID: %w", err) | ||
| 230 | } | ||
| 231 | |||
| 232 | return env, nil | ||
| 233 | } | ||
| 234 | |||
| 235 | func parseClosedEnvelope(arr []json.RawMessage) (*ClosedEnvelope, error) { | ||
| 236 | if len(arr) != 3 { | ||
| 237 | return nil, fmt.Errorf("CLOSED envelope must have exactly 3 elements") | ||
| 238 | } | ||
| 239 | |||
| 240 | env := &ClosedEnvelope{} | ||
| 241 | if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { | ||
| 242 | return nil, fmt.Errorf("invalid subscription ID: %w", err) | ||
| 243 | } | ||
| 244 | if err := json.Unmarshal(arr[2], &env.Message); err != nil { | ||
| 245 | return nil, fmt.Errorf("invalid message: %w", err) | ||
| 246 | } | ||
| 247 | |||
| 248 | return env, nil | ||
| 249 | } | ||
| 250 | |||
| 251 | func parseNoticeEnvelope(arr []json.RawMessage) (*NoticeEnvelope, error) { | ||
| 252 | if len(arr) != 2 { | ||
| 253 | return nil, fmt.Errorf("NOTICE envelope must have exactly 2 elements") | ||
| 254 | } | ||
| 255 | |||
| 256 | env := &NoticeEnvelope{} | ||
| 257 | if err := json.Unmarshal(arr[1], &env.Message); err != nil { | ||
| 258 | return nil, fmt.Errorf("invalid message: %w", err) | ||
| 259 | } | ||
| 260 | |||
| 261 | return env, nil | ||
| 262 | } | ||
diff --git a/internal/nostr/envelope_test.go b/internal/nostr/envelope_test.go new file mode 100644 index 0000000..8f79ad5 --- /dev/null +++ b/internal/nostr/envelope_test.go | |||
| @@ -0,0 +1,416 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "encoding/json" | ||
| 5 | "testing" | ||
| 6 | ) | ||
| 7 | |||
| 8 | func TestEventEnvelopeMarshalJSON(t *testing.T) { | ||
| 9 | event := &Event{ | ||
| 10 | ID: "abc123", | ||
| 11 | PubKey: "def456", | ||
| 12 | CreatedAt: 1704067200, | ||
| 13 | Kind: 1, | ||
| 14 | Tags: Tags{}, | ||
| 15 | Content: "Hello", | ||
| 16 | Sig: "sig789", | ||
| 17 | } | ||
| 18 | |||
| 19 | t.Run("client to relay", func(t *testing.T) { | ||
| 20 | env := EventEnvelope{Event: event} | ||
| 21 | data, err := env.MarshalJSON() | ||
| 22 | if err != nil { | ||
| 23 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 24 | } | ||
| 25 | |||
| 26 | var arr []json.RawMessage | ||
| 27 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 28 | t.Fatalf("Invalid JSON: %v", err) | ||
| 29 | } | ||
| 30 | |||
| 31 | if len(arr) != 2 { | ||
| 32 | t.Errorf("Array length = %d, want 2", len(arr)) | ||
| 33 | } | ||
| 34 | |||
| 35 | var label string | ||
| 36 | json.Unmarshal(arr[0], &label) | ||
| 37 | if label != "EVENT" { | ||
| 38 | t.Errorf("Label = %s, want EVENT", label) | ||
| 39 | } | ||
| 40 | }) | ||
| 41 | |||
| 42 | t.Run("relay to client", func(t *testing.T) { | ||
| 43 | env := EventEnvelope{SubscriptionID: "sub1", Event: event} | ||
| 44 | data, err := env.MarshalJSON() | ||
| 45 | if err != nil { | ||
| 46 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 47 | } | ||
| 48 | |||
| 49 | var arr []json.RawMessage | ||
| 50 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 51 | t.Fatalf("Invalid JSON: %v", err) | ||
| 52 | } | ||
| 53 | |||
| 54 | if len(arr) != 3 { | ||
| 55 | t.Errorf("Array length = %d, want 3", len(arr)) | ||
| 56 | } | ||
| 57 | }) | ||
| 58 | } | ||
| 59 | |||
| 60 | func TestReqEnvelopeMarshalJSON(t *testing.T) { | ||
| 61 | env := ReqEnvelope{ | ||
| 62 | SubscriptionID: "sub1", | ||
| 63 | Filters: []Filter{ | ||
| 64 | {Kinds: []int{1}}, | ||
| 65 | {Authors: []string{"abc123"}}, | ||
| 66 | }, | ||
| 67 | } | ||
| 68 | |||
| 69 | data, err := env.MarshalJSON() | ||
| 70 | if err != nil { | ||
| 71 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 72 | } | ||
| 73 | |||
| 74 | var arr []json.RawMessage | ||
| 75 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 76 | t.Fatalf("Invalid JSON: %v", err) | ||
| 77 | } | ||
| 78 | |||
| 79 | if len(arr) != 4 { // ["REQ", "sub1", filter1, filter2] | ||
| 80 | t.Errorf("Array length = %d, want 4", len(arr)) | ||
| 81 | } | ||
| 82 | |||
| 83 | var label string | ||
| 84 | json.Unmarshal(arr[0], &label) | ||
| 85 | if label != "REQ" { | ||
| 86 | t.Errorf("Label = %s, want REQ", label) | ||
| 87 | } | ||
| 88 | |||
| 89 | var subID string | ||
| 90 | json.Unmarshal(arr[1], &subID) | ||
| 91 | if subID != "sub1" { | ||
| 92 | t.Errorf("SubscriptionID = %s, want sub1", subID) | ||
| 93 | } | ||
| 94 | } | ||
| 95 | |||
| 96 | func TestCloseEnvelopeMarshalJSON(t *testing.T) { | ||
| 97 | env := CloseEnvelope{SubscriptionID: "sub1"} | ||
| 98 | data, err := env.MarshalJSON() | ||
| 99 | if err != nil { | ||
| 100 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 101 | } | ||
| 102 | |||
| 103 | var arr []interface{} | ||
| 104 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 105 | t.Fatalf("Invalid JSON: %v", err) | ||
| 106 | } | ||
| 107 | |||
| 108 | if len(arr) != 2 { | ||
| 109 | t.Errorf("Array length = %d, want 2", len(arr)) | ||
| 110 | } | ||
| 111 | if arr[0] != "CLOSE" { | ||
| 112 | t.Errorf("Label = %v, want CLOSE", arr[0]) | ||
| 113 | } | ||
| 114 | if arr[1] != "sub1" { | ||
| 115 | t.Errorf("SubscriptionID = %v, want sub1", arr[1]) | ||
| 116 | } | ||
| 117 | } | ||
| 118 | |||
| 119 | func TestOKEnvelopeMarshalJSON(t *testing.T) { | ||
| 120 | env := OKEnvelope{ | ||
| 121 | EventID: "event123", | ||
| 122 | OK: true, | ||
| 123 | Message: "accepted", | ||
| 124 | } | ||
| 125 | |||
| 126 | data, err := env.MarshalJSON() | ||
| 127 | if err != nil { | ||
| 128 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 129 | } | ||
| 130 | |||
| 131 | var arr []interface{} | ||
| 132 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 133 | t.Fatalf("Invalid JSON: %v", err) | ||
| 134 | } | ||
| 135 | |||
| 136 | if len(arr) != 4 { | ||
| 137 | t.Errorf("Array length = %d, want 4", len(arr)) | ||
| 138 | } | ||
| 139 | if arr[0] != "OK" { | ||
| 140 | t.Errorf("Label = %v, want OK", arr[0]) | ||
| 141 | } | ||
| 142 | if arr[1] != "event123" { | ||
| 143 | t.Errorf("EventID = %v, want event123", arr[1]) | ||
| 144 | } | ||
| 145 | if arr[2] != true { | ||
| 146 | t.Errorf("OK = %v, want true", arr[2]) | ||
| 147 | } | ||
| 148 | if arr[3] != "accepted" { | ||
| 149 | t.Errorf("Message = %v, want accepted", arr[3]) | ||
| 150 | } | ||
| 151 | } | ||
| 152 | |||
| 153 | func TestEOSEEnvelopeMarshalJSON(t *testing.T) { | ||
| 154 | env := EOSEEnvelope{SubscriptionID: "sub1"} | ||
| 155 | data, err := env.MarshalJSON() | ||
| 156 | if err != nil { | ||
| 157 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 158 | } | ||
| 159 | |||
| 160 | var arr []interface{} | ||
| 161 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 162 | t.Fatalf("Invalid JSON: %v", err) | ||
| 163 | } | ||
| 164 | |||
| 165 | if len(arr) != 2 { | ||
| 166 | t.Errorf("Array length = %d, want 2", len(arr)) | ||
| 167 | } | ||
| 168 | if arr[0] != "EOSE" { | ||
| 169 | t.Errorf("Label = %v, want EOSE", arr[0]) | ||
| 170 | } | ||
| 171 | } | ||
| 172 | |||
| 173 | func TestClosedEnvelopeMarshalJSON(t *testing.T) { | ||
| 174 | env := ClosedEnvelope{ | ||
| 175 | SubscriptionID: "sub1", | ||
| 176 | Message: "rate limited", | ||
| 177 | } | ||
| 178 | |||
| 179 | data, err := env.MarshalJSON() | ||
| 180 | if err != nil { | ||
| 181 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 182 | } | ||
| 183 | |||
| 184 | var arr []interface{} | ||
| 185 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 186 | t.Fatalf("Invalid JSON: %v", err) | ||
| 187 | } | ||
| 188 | |||
| 189 | if len(arr) != 3 { | ||
| 190 | t.Errorf("Array length = %d, want 3", len(arr)) | ||
| 191 | } | ||
| 192 | if arr[0] != "CLOSED" { | ||
| 193 | t.Errorf("Label = %v, want CLOSED", arr[0]) | ||
| 194 | } | ||
| 195 | } | ||
| 196 | |||
| 197 | func TestNoticeEnvelopeMarshalJSON(t *testing.T) { | ||
| 198 | env := NoticeEnvelope{Message: "error: rate limited"} | ||
| 199 | data, err := env.MarshalJSON() | ||
| 200 | if err != nil { | ||
| 201 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 202 | } | ||
| 203 | |||
| 204 | var arr []interface{} | ||
| 205 | if err := json.Unmarshal(data, &arr); err != nil { | ||
| 206 | t.Fatalf("Invalid JSON: %v", err) | ||
| 207 | } | ||
| 208 | |||
| 209 | if len(arr) != 2 { | ||
| 210 | t.Errorf("Array length = %d, want 2", len(arr)) | ||
| 211 | } | ||
| 212 | if arr[0] != "NOTICE" { | ||
| 213 | t.Errorf("Label = %v, want NOTICE", arr[0]) | ||
| 214 | } | ||
| 215 | } | ||
| 216 | |||
| 217 | func TestParseEnvelopeEvent(t *testing.T) { | ||
| 218 | t.Run("client to relay", func(t *testing.T) { | ||
| 219 | data := `["EVENT",{"id":"abc123","pubkey":"def456","created_at":1704067200,"kind":1,"tags":[],"content":"Hello","sig":"sig789"}]` | ||
| 220 | env, err := ParseEnvelope([]byte(data)) | ||
| 221 | if err != nil { | ||
| 222 | t.Fatalf("ParseEnvelope() error = %v", err) | ||
| 223 | } | ||
| 224 | |||
| 225 | eventEnv, ok := env.(*EventEnvelope) | ||
| 226 | if !ok { | ||
| 227 | t.Fatalf("Expected *EventEnvelope, got %T", env) | ||
| 228 | } | ||
| 229 | |||
| 230 | if eventEnv.SubscriptionID != "" { | ||
| 231 | t.Errorf("SubscriptionID = %s, want empty", eventEnv.SubscriptionID) | ||
| 232 | } | ||
| 233 | if eventEnv.Event.ID != "abc123" { | ||
| 234 | t.Errorf("Event.ID = %s, want abc123", eventEnv.Event.ID) | ||
| 235 | } | ||
| 236 | }) | ||
| 237 | |||
| 238 | t.Run("relay to client", func(t *testing.T) { | ||
| 239 | data := `["EVENT","sub1",{"id":"abc123","pubkey":"def456","created_at":1704067200,"kind":1,"tags":[],"content":"Hello","sig":"sig789"}]` | ||
| 240 | env, err := ParseEnvelope([]byte(data)) | ||
| 241 | if err != nil { | ||
| 242 | t.Fatalf("ParseEnvelope() error = %v", err) | ||
| 243 | } | ||
| 244 | |||
| 245 | eventEnv, ok := env.(*EventEnvelope) | ||
| 246 | if !ok { | ||
| 247 | t.Fatalf("Expected *EventEnvelope, got %T", env) | ||
| 248 | } | ||
| 249 | |||
| 250 | if eventEnv.SubscriptionID != "sub1" { | ||
| 251 | t.Errorf("SubscriptionID = %s, want sub1", eventEnv.SubscriptionID) | ||
| 252 | } | ||
| 253 | }) | ||
| 254 | } | ||
| 255 | |||
| 256 | func TestParseEnvelopeReq(t *testing.T) { | ||
| 257 | data := `["REQ","sub1",{"kinds":[1]},{"authors":["abc123"]}]` | ||
| 258 | env, err := ParseEnvelope([]byte(data)) | ||
| 259 | if err != nil { | ||
| 260 | t.Fatalf("ParseEnvelope() error = %v", err) | ||
| 261 | } | ||
| 262 | |||
| 263 | reqEnv, ok := env.(*ReqEnvelope) | ||
| 264 | if !ok { | ||
| 265 | t.Fatalf("Expected *ReqEnvelope, got %T", env) | ||
| 266 | } | ||
| 267 | |||
| 268 | if reqEnv.SubscriptionID != "sub1" { | ||
| 269 | t.Errorf("SubscriptionID = %s, want sub1", reqEnv.SubscriptionID) | ||
| 270 | } | ||
| 271 | if len(reqEnv.Filters) != 2 { | ||
| 272 | t.Errorf("Filters length = %d, want 2", len(reqEnv.Filters)) | ||
| 273 | } | ||
| 274 | } | ||
| 275 | |||
| 276 | func TestParseEnvelopeClose(t *testing.T) { | ||
| 277 | data := `["CLOSE","sub1"]` | ||
| 278 | env, err := ParseEnvelope([]byte(data)) | ||
| 279 | if err != nil { | ||
| 280 | t.Fatalf("ParseEnvelope() error = %v", err) | ||
| 281 | } | ||
| 282 | |||
| 283 | closeEnv, ok := env.(*CloseEnvelope) | ||
| 284 | if !ok { | ||
| 285 | t.Fatalf("Expected *CloseEnvelope, got %T", env) | ||
| 286 | } | ||
| 287 | |||
| 288 | if closeEnv.SubscriptionID != "sub1" { | ||
| 289 | t.Errorf("SubscriptionID = %s, want sub1", closeEnv.SubscriptionID) | ||
| 290 | } | ||
| 291 | } | ||
| 292 | |||
| 293 | func TestParseEnvelopeOK(t *testing.T) { | ||
| 294 | data := `["OK","event123",true,"accepted"]` | ||
| 295 | env, err := ParseEnvelope([]byte(data)) | ||
| 296 | if err != nil { | ||
| 297 | t.Fatalf("ParseEnvelope() error = %v", err) | ||
| 298 | } | ||
| 299 | |||
| 300 | okEnv, ok := env.(*OKEnvelope) | ||
| 301 | if !ok { | ||
| 302 | t.Fatalf("Expected *OKEnvelope, got %T", env) | ||
| 303 | } | ||
| 304 | |||
| 305 | if okEnv.EventID != "event123" { | ||
| 306 | t.Errorf("EventID = %s, want event123", okEnv.EventID) | ||
| 307 | } | ||
| 308 | if !okEnv.OK { | ||
| 309 | t.Error("OK = false, want true") | ||
| 310 | } | ||
| 311 | if okEnv.Message != "accepted" { | ||
| 312 | t.Errorf("Message = %s, want accepted", okEnv.Message) | ||
| 313 | } | ||
| 314 | } | ||
| 315 | |||
| 316 | func TestParseEnvelopeEOSE(t *testing.T) { | ||
| 317 | data := `["EOSE","sub1"]` | ||
| 318 | env, err := ParseEnvelope([]byte(data)) | ||
| 319 | if err != nil { | ||
| 320 | t.Fatalf("ParseEnvelope() error = %v", err) | ||
| 321 | } | ||
| 322 | |||
| 323 | eoseEnv, ok := env.(*EOSEEnvelope) | ||
| 324 | if !ok { | ||
| 325 | t.Fatalf("Expected *EOSEEnvelope, got %T", env) | ||
| 326 | } | ||
| 327 | |||
| 328 | if eoseEnv.SubscriptionID != "sub1" { | ||
| 329 | t.Errorf("SubscriptionID = %s, want sub1", eoseEnv.SubscriptionID) | ||
| 330 | } | ||
| 331 | } | ||
| 332 | |||
| 333 | func TestParseEnvelopeClosed(t *testing.T) { | ||
| 334 | data := `["CLOSED","sub1","rate limited"]` | ||
| 335 | env, err := ParseEnvelope([]byte(data)) | ||
| 336 | if err != nil { | ||
| 337 | t.Fatalf("ParseEnvelope() error = %v", err) | ||
| 338 | } | ||
| 339 | |||
| 340 | closedEnv, ok := env.(*ClosedEnvelope) | ||
| 341 | if !ok { | ||
| 342 | t.Fatalf("Expected *ClosedEnvelope, got %T", env) | ||
| 343 | } | ||
| 344 | |||
| 345 | if closedEnv.SubscriptionID != "sub1" { | ||
| 346 | t.Errorf("SubscriptionID = %s, want sub1", closedEnv.SubscriptionID) | ||
| 347 | } | ||
| 348 | if closedEnv.Message != "rate limited" { | ||
| 349 | t.Errorf("Message = %s, want rate limited", closedEnv.Message) | ||
| 350 | } | ||
| 351 | } | ||
| 352 | |||
| 353 | func TestParseEnvelopeNotice(t *testing.T) { | ||
| 354 | data := `["NOTICE","error: rate limited"]` | ||
| 355 | env, err := ParseEnvelope([]byte(data)) | ||
| 356 | if err != nil { | ||
| 357 | t.Fatalf("ParseEnvelope() error = %v", err) | ||
| 358 | } | ||
| 359 | |||
| 360 | noticeEnv, ok := env.(*NoticeEnvelope) | ||
| 361 | if !ok { | ||
| 362 | t.Fatalf("Expected *NoticeEnvelope, got %T", env) | ||
| 363 | } | ||
| 364 | |||
| 365 | if noticeEnv.Message != "error: rate limited" { | ||
| 366 | t.Errorf("Message = %s, want 'error: rate limited'", noticeEnv.Message) | ||
| 367 | } | ||
| 368 | } | ||
| 369 | |||
| 370 | func TestParseEnvelopeErrors(t *testing.T) { | ||
| 371 | tests := []struct { | ||
| 372 | name string | ||
| 373 | data string | ||
| 374 | }{ | ||
| 375 | {"invalid json", "not json"}, | ||
| 376 | {"not array", `{"type":"EVENT"}`}, | ||
| 377 | {"empty array", `[]`}, | ||
| 378 | {"single element", `["EVENT"]`}, | ||
| 379 | {"unknown type", `["UNKNOWN","data"]`}, | ||
| 380 | {"invalid event length", `["EVENT","a","b","c"]`}, | ||
| 381 | {"invalid ok length", `["OK","id",true]`}, | ||
| 382 | {"invalid eose length", `["EOSE"]`}, | ||
| 383 | } | ||
| 384 | |||
| 385 | for _, tt := range tests { | ||
| 386 | t.Run(tt.name, func(t *testing.T) { | ||
| 387 | _, err := ParseEnvelope([]byte(tt.data)) | ||
| 388 | if err == nil { | ||
| 389 | t.Error("ParseEnvelope() expected error, got nil") | ||
| 390 | } | ||
| 391 | }) | ||
| 392 | } | ||
| 393 | } | ||
| 394 | |||
| 395 | func TestEnvelopeLabel(t *testing.T) { | ||
| 396 | tests := []struct { | ||
| 397 | env Envelope | ||
| 398 | label string | ||
| 399 | }{ | ||
| 400 | {EventEnvelope{}, "EVENT"}, | ||
| 401 | {ReqEnvelope{}, "REQ"}, | ||
| 402 | {CloseEnvelope{}, "CLOSE"}, | ||
| 403 | {OKEnvelope{}, "OK"}, | ||
| 404 | {EOSEEnvelope{}, "EOSE"}, | ||
| 405 | {ClosedEnvelope{}, "CLOSED"}, | ||
| 406 | {NoticeEnvelope{}, "NOTICE"}, | ||
| 407 | } | ||
| 408 | |||
| 409 | for _, tt := range tests { | ||
| 410 | t.Run(tt.label, func(t *testing.T) { | ||
| 411 | if got := tt.env.Label(); got != tt.label { | ||
| 412 | t.Errorf("Label() = %s, want %s", got, tt.label) | ||
| 413 | } | ||
| 414 | }) | ||
| 415 | } | ||
| 416 | } | ||
diff --git a/internal/nostr/event.go b/internal/nostr/event.go new file mode 100644 index 0000000..a8156bb --- /dev/null +++ b/internal/nostr/event.go | |||
| @@ -0,0 +1,72 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "crypto/sha256" | ||
| 5 | "encoding/hex" | ||
| 6 | "encoding/json" | ||
| 7 | "fmt" | ||
| 8 | ) | ||
| 9 | |||
| 10 | // Event represents a Nostr event as defined in NIP-01. | ||
| 11 | type Event struct { | ||
| 12 | ID string `json:"id"` | ||
| 13 | PubKey string `json:"pubkey"` | ||
| 14 | CreatedAt int64 `json:"created_at"` | ||
| 15 | Kind int `json:"kind"` | ||
| 16 | Tags Tags `json:"tags"` | ||
| 17 | Content string `json:"content"` | ||
| 18 | Sig string `json:"sig"` | ||
| 19 | } | ||
| 20 | |||
| 21 | // Serialize returns the canonical JSON serialization of the event for ID computation. | ||
| 22 | // Format: [0, "pubkey", created_at, kind, tags, "content"] | ||
| 23 | func (e *Event) Serialize() []byte { | ||
| 24 | // Use json.Marshal for proper escaping of content and tags | ||
| 25 | arr := []interface{}{ | ||
| 26 | 0, | ||
| 27 | e.PubKey, | ||
| 28 | e.CreatedAt, | ||
| 29 | e.Kind, | ||
| 30 | e.Tags, | ||
| 31 | e.Content, | ||
| 32 | } | ||
| 33 | data, _ := json.Marshal(arr) | ||
| 34 | return data | ||
| 35 | } | ||
| 36 | |||
| 37 | // ComputeID calculates the SHA256 hash of the serialized event. | ||
| 38 | // Returns the 64-character hex-encoded ID. | ||
| 39 | func (e *Event) ComputeID() string { | ||
| 40 | serialized := e.Serialize() | ||
| 41 | hash := sha256.Sum256(serialized) | ||
| 42 | return hex.EncodeToString(hash[:]) | ||
| 43 | } | ||
| 44 | |||
| 45 | // SetID computes and sets the event ID. | ||
| 46 | func (e *Event) SetID() { | ||
| 47 | e.ID = e.ComputeID() | ||
| 48 | } | ||
| 49 | |||
| 50 | // CheckID verifies that the event ID matches the computed ID. | ||
| 51 | func (e *Event) CheckID() bool { | ||
| 52 | return e.ID == e.ComputeID() | ||
| 53 | } | ||
| 54 | |||
| 55 | // MarshalJSON implements json.Marshaler with empty tags as [] instead of null. | ||
| 56 | func (e Event) MarshalJSON() ([]byte, error) { | ||
| 57 | type eventAlias Event | ||
| 58 | ea := eventAlias(e) | ||
| 59 | if ea.Tags == nil { | ||
| 60 | ea.Tags = Tags{} | ||
| 61 | } | ||
| 62 | return json.Marshal(ea) | ||
| 63 | } | ||
| 64 | |||
| 65 | // String returns a JSON representation of the event for debugging. | ||
| 66 | func (e *Event) String() string { | ||
| 67 | data, err := json.MarshalIndent(e, "", " ") | ||
| 68 | if err != nil { | ||
| 69 | return fmt.Sprintf("<Event error: %v>", err) | ||
| 70 | } | ||
| 71 | return string(data) | ||
| 72 | } | ||
diff --git a/internal/nostr/event_test.go b/internal/nostr/event_test.go new file mode 100644 index 0000000..eff4103 --- /dev/null +++ b/internal/nostr/event_test.go | |||
| @@ -0,0 +1,194 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "encoding/json" | ||
| 5 | "testing" | ||
| 6 | ) | ||
| 7 | |||
| 8 | func TestEventSerialize(t *testing.T) { | ||
| 9 | event := &Event{ | ||
| 10 | PubKey: "7e7e9c42a91bfef19fa929e5fda1b72e0ebc1a4c1141673e2794234d86addf4e", | ||
| 11 | CreatedAt: 1704067200, | ||
| 12 | Kind: 1, | ||
| 13 | Tags: Tags{{"e", "abc123"}, {"p", "def456"}}, | ||
| 14 | Content: "Hello, Nostr!", | ||
| 15 | } | ||
| 16 | |||
| 17 | serialized := event.Serialize() | ||
| 18 | |||
| 19 | // Parse the JSON to verify structure | ||
| 20 | var arr []interface{} | ||
| 21 | if err := json.Unmarshal(serialized, &arr); err != nil { | ||
| 22 | t.Fatalf("Serialize() produced invalid JSON: %v", err) | ||
| 23 | } | ||
| 24 | |||
| 25 | if len(arr) != 6 { | ||
| 26 | t.Fatalf("Serialized array has %d elements, want 6", len(arr)) | ||
| 27 | } | ||
| 28 | |||
| 29 | // Check each element | ||
| 30 | if arr[0].(float64) != 0 { | ||
| 31 | t.Errorf("arr[0] = %v, want 0", arr[0]) | ||
| 32 | } | ||
| 33 | if arr[1].(string) != event.PubKey { | ||
| 34 | t.Errorf("arr[1] = %v, want %s", arr[1], event.PubKey) | ||
| 35 | } | ||
| 36 | if int64(arr[2].(float64)) != event.CreatedAt { | ||
| 37 | t.Errorf("arr[2] = %v, want %d", arr[2], event.CreatedAt) | ||
| 38 | } | ||
| 39 | if int(arr[3].(float64)) != event.Kind { | ||
| 40 | t.Errorf("arr[3] = %v, want %d", arr[3], event.Kind) | ||
| 41 | } | ||
| 42 | if arr[5].(string) != event.Content { | ||
| 43 | t.Errorf("arr[5] = %v, want %s", arr[5], event.Content) | ||
| 44 | } | ||
| 45 | } | ||
| 46 | |||
| 47 | func TestEventComputeID(t *testing.T) { | ||
| 48 | // Test with a known event (you can verify with other implementations) | ||
| 49 | event := &Event{ | ||
| 50 | PubKey: "7e7e9c42a91bfef19fa929e5fda1b72e0ebc1a4c1141673e2794234d86addf4e", | ||
| 51 | CreatedAt: 1704067200, | ||
| 52 | Kind: 1, | ||
| 53 | Tags: Tags{}, | ||
| 54 | Content: "Hello, Nostr!", | ||
| 55 | } | ||
| 56 | |||
| 57 | id := event.ComputeID() | ||
| 58 | |||
| 59 | // ID should be 64 hex characters | ||
| 60 | if len(id) != 64 { | ||
| 61 | t.Errorf("ComputeID() returned ID of length %d, want 64", len(id)) | ||
| 62 | } | ||
| 63 | |||
| 64 | // Verify it's valid hex | ||
| 65 | for _, c := range id { | ||
| 66 | if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { | ||
| 67 | t.Errorf("ComputeID() returned invalid hex character: %c", c) | ||
| 68 | } | ||
| 69 | } | ||
| 70 | |||
| 71 | // Verify consistency | ||
| 72 | id2 := event.ComputeID() | ||
| 73 | if id != id2 { | ||
| 74 | t.Errorf("ComputeID() is not consistent: %s != %s", id, id2) | ||
| 75 | } | ||
| 76 | } | ||
| 77 | |||
| 78 | func TestEventSetID(t *testing.T) { | ||
| 79 | event := &Event{ | ||
| 80 | PubKey: "7e7e9c42a91bfef19fa929e5fda1b72e0ebc1a4c1141673e2794234d86addf4e", | ||
| 81 | CreatedAt: 1704067200, | ||
| 82 | Kind: 1, | ||
| 83 | Tags: Tags{}, | ||
| 84 | Content: "Test", | ||
| 85 | } | ||
| 86 | |||
| 87 | event.SetID() | ||
| 88 | if event.ID == "" { | ||
| 89 | t.Error("SetID() did not set ID") | ||
| 90 | } | ||
| 91 | if !event.CheckID() { | ||
| 92 | t.Error("CheckID() returned false after SetID()") | ||
| 93 | } | ||
| 94 | } | ||
| 95 | |||
| 96 | func TestEventCheckID(t *testing.T) { | ||
| 97 | event := &Event{ | ||
| 98 | PubKey: "7e7e9c42a91bfef19fa929e5fda1b72e0ebc1a4c1141673e2794234d86addf4e", | ||
| 99 | CreatedAt: 1704067200, | ||
| 100 | Kind: 1, | ||
| 101 | Tags: Tags{}, | ||
| 102 | Content: "Test", | ||
| 103 | } | ||
| 104 | |||
| 105 | event.SetID() | ||
| 106 | |||
| 107 | if !event.CheckID() { | ||
| 108 | t.Error("CheckID() returned false for valid ID") | ||
| 109 | } | ||
| 110 | |||
| 111 | // Corrupt the ID | ||
| 112 | event.ID = "0000000000000000000000000000000000000000000000000000000000000000" | ||
| 113 | if event.CheckID() { | ||
| 114 | t.Error("CheckID() returned true for invalid ID") | ||
| 115 | } | ||
| 116 | } | ||
| 117 | |||
| 118 | func TestEventMarshalJSON(t *testing.T) { | ||
| 119 | event := Event{ | ||
| 120 | ID: "abc123", | ||
| 121 | PubKey: "def456", | ||
| 122 | CreatedAt: 1704067200, | ||
| 123 | Kind: 1, | ||
| 124 | Tags: nil, // nil tags | ||
| 125 | Content: "Test", | ||
| 126 | Sig: "sig789", | ||
| 127 | } | ||
| 128 | |||
| 129 | data, err := json.Marshal(event) | ||
| 130 | if err != nil { | ||
| 131 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 132 | } | ||
| 133 | |||
| 134 | // Verify tags is [] not null | ||
| 135 | var m map[string]interface{} | ||
| 136 | if err := json.Unmarshal(data, &m); err != nil { | ||
| 137 | t.Fatalf("Failed to unmarshal: %v", err) | ||
| 138 | } | ||
| 139 | |||
| 140 | tags, ok := m["tags"] | ||
| 141 | if !ok { | ||
| 142 | t.Error("tags field missing from JSON") | ||
| 143 | } | ||
| 144 | if tags == nil { | ||
| 145 | t.Error("tags is null, want []") | ||
| 146 | } | ||
| 147 | if arr, ok := tags.([]interface{}); !ok || len(arr) != 0 { | ||
| 148 | t.Errorf("tags = %v, want []", tags) | ||
| 149 | } | ||
| 150 | } | ||
| 151 | |||
| 152 | func TestEventJSONRoundTrip(t *testing.T) { | ||
| 153 | original := Event{ | ||
| 154 | ID: "abc123def456", | ||
| 155 | PubKey: "pubkey123", | ||
| 156 | CreatedAt: 1704067200, | ||
| 157 | Kind: 1, | ||
| 158 | Tags: Tags{{"e", "event1"}, {"p", "pubkey1", "relay"}}, | ||
| 159 | Content: "Hello with \"quotes\" and \n newlines", | ||
| 160 | Sig: "signature123", | ||
| 161 | } | ||
| 162 | |||
| 163 | data, err := json.Marshal(original) | ||
| 164 | if err != nil { | ||
| 165 | t.Fatalf("Marshal error: %v", err) | ||
| 166 | } | ||
| 167 | |||
| 168 | var decoded Event | ||
| 169 | if err := json.Unmarshal(data, &decoded); err != nil { | ||
| 170 | t.Fatalf("Unmarshal error: %v", err) | ||
| 171 | } | ||
| 172 | |||
| 173 | if decoded.ID != original.ID { | ||
| 174 | t.Errorf("ID mismatch: %s != %s", decoded.ID, original.ID) | ||
| 175 | } | ||
| 176 | if decoded.PubKey != original.PubKey { | ||
| 177 | t.Errorf("PubKey mismatch: %s != %s", decoded.PubKey, original.PubKey) | ||
| 178 | } | ||
| 179 | if decoded.CreatedAt != original.CreatedAt { | ||
| 180 | t.Errorf("CreatedAt mismatch: %d != %d", decoded.CreatedAt, original.CreatedAt) | ||
| 181 | } | ||
| 182 | if decoded.Kind != original.Kind { | ||
| 183 | t.Errorf("Kind mismatch: %d != %d", decoded.Kind, original.Kind) | ||
| 184 | } | ||
| 185 | if decoded.Content != original.Content { | ||
| 186 | t.Errorf("Content mismatch: %s != %s", decoded.Content, original.Content) | ||
| 187 | } | ||
| 188 | if decoded.Sig != original.Sig { | ||
| 189 | t.Errorf("Sig mismatch: %s != %s", decoded.Sig, original.Sig) | ||
| 190 | } | ||
| 191 | if len(decoded.Tags) != len(original.Tags) { | ||
| 192 | t.Errorf("Tags length mismatch: %d != %d", len(decoded.Tags), len(original.Tags)) | ||
| 193 | } | ||
| 194 | } | ||
diff --git a/internal/nostr/example_test.go b/internal/nostr/example_test.go new file mode 100644 index 0000000..80acd21 --- /dev/null +++ b/internal/nostr/example_test.go | |||
| @@ -0,0 +1,85 @@ | |||
| 1 | package nostr_test | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "context" | ||
| 5 | "fmt" | ||
| 6 | "time" | ||
| 7 | |||
| 8 | "northwest.io/nostr-grpc/internal/nostr" | ||
| 9 | ) | ||
| 10 | |||
| 11 | // Example_basic demonstrates basic usage of the nostr library. | ||
| 12 | func Example_basic() { | ||
| 13 | // Generate a new key pair | ||
| 14 | key, err := nostr.GenerateKey() | ||
| 15 | if err != nil { | ||
| 16 | fmt.Printf("Failed to generate key: %v\n", err) | ||
| 17 | return | ||
| 18 | } | ||
| 19 | |||
| 20 | fmt.Printf("Public key (hex): %s...\n", key.Public()[:16]) | ||
| 21 | fmt.Printf("Public key (npub): %s...\n", key.Npub()[:20]) | ||
| 22 | |||
| 23 | // Create an event | ||
| 24 | event := &nostr.Event{ | ||
| 25 | CreatedAt: time.Now().Unix(), | ||
| 26 | Kind: nostr.KindTextNote, | ||
| 27 | Tags: nostr.Tags{{"t", "test"}}, | ||
| 28 | Content: "Hello from nostr-go!", | ||
| 29 | } | ||
| 30 | |||
| 31 | // Sign the event | ||
| 32 | if err := key.Sign(event); err != nil { | ||
| 33 | fmt.Printf("Failed to sign event: %v\n", err) | ||
| 34 | return | ||
| 35 | } | ||
| 36 | |||
| 37 | // Verify the signature | ||
| 38 | if event.Verify() { | ||
| 39 | fmt.Println("Event signature verified!") | ||
| 40 | } | ||
| 41 | |||
| 42 | // Create a filter to match our event | ||
| 43 | filter := nostr.Filter{ | ||
| 44 | Kinds: []int{nostr.KindTextNote}, | ||
| 45 | Authors: []string{key.Public()[:8]}, // Prefix matching | ||
| 46 | } | ||
| 47 | |||
| 48 | if filter.Matches(event) { | ||
| 49 | fmt.Println("Filter matches the event!") | ||
| 50 | } | ||
| 51 | } | ||
| 52 | |||
| 53 | // ExampleRelay demonstrates connecting to a relay (requires network). | ||
| 54 | // This is a documentation example - run with: go test -v -run ExampleRelay | ||
| 55 | func ExampleRelay() { | ||
| 56 | ctx := context.Background() | ||
| 57 | |||
| 58 | // Connect to a public relay | ||
| 59 | relay, err := nostr.Connect(ctx, "wss://relay.damus.io") | ||
| 60 | if err != nil { | ||
| 61 | fmt.Printf("Failed to connect: %v\n", err) | ||
| 62 | return | ||
| 63 | } | ||
| 64 | defer relay.Close() | ||
| 65 | |||
| 66 | fmt.Println("Connected to relay!") | ||
| 67 | |||
| 68 | ctx, cancel := context.WithTimeout(ctx, 10*time.Second) | ||
| 69 | defer cancel() | ||
| 70 | |||
| 71 | // Fetch recent text notes (closes on EOSE) | ||
| 72 | since := time.Now().Add(-1 * time.Hour).Unix() | ||
| 73 | sub := relay.Fetch(ctx, nostr.Filter{ | ||
| 74 | Kinds: []int{nostr.KindTextNote}, | ||
| 75 | Since: &since, | ||
| 76 | Limit: 5, | ||
| 77 | }) | ||
| 78 | |||
| 79 | eventCount := 0 | ||
| 80 | for event := range sub.Events { | ||
| 81 | eventCount++ | ||
| 82 | fmt.Printf("Received event from %s...\n", event.PubKey[:8]) | ||
| 83 | } | ||
| 84 | fmt.Printf("Received %d events\n", eventCount) | ||
| 85 | } | ||
diff --git a/internal/nostr/filter.go b/internal/nostr/filter.go new file mode 100644 index 0000000..dde04a5 --- /dev/null +++ b/internal/nostr/filter.go | |||
| @@ -0,0 +1,224 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "encoding/json" | ||
| 5 | "strings" | ||
| 6 | ) | ||
| 7 | |||
| 8 | // Filter represents a subscription filter as defined in NIP-01. | ||
| 9 | type Filter struct { | ||
| 10 | IDs []string `json:"ids,omitempty"` | ||
| 11 | Kinds []int `json:"kinds,omitempty"` | ||
| 12 | Authors []string `json:"authors,omitempty"` | ||
| 13 | Tags map[string][]string `json:"-"` // Custom marshaling for #e, #p, etc. | ||
| 14 | Since *int64 `json:"since,omitempty"` | ||
| 15 | Until *int64 `json:"until,omitempty"` | ||
| 16 | Limit int `json:"limit,omitempty"` | ||
| 17 | } | ||
| 18 | |||
| 19 | // MarshalJSON implements json.Marshaler for Filter. | ||
| 20 | // Converts Tags map to #e, #p format. | ||
| 21 | func (f Filter) MarshalJSON() ([]byte, error) { | ||
| 22 | // Create a map for custom marshaling | ||
| 23 | m := make(map[string]interface{}) | ||
| 24 | |||
| 25 | if len(f.IDs) > 0 { | ||
| 26 | m["ids"] = f.IDs | ||
| 27 | } | ||
| 28 | if len(f.Kinds) > 0 { | ||
| 29 | m["kinds"] = f.Kinds | ||
| 30 | } | ||
| 31 | if len(f.Authors) > 0 { | ||
| 32 | m["authors"] = f.Authors | ||
| 33 | } | ||
| 34 | if f.Since != nil { | ||
| 35 | m["since"] = *f.Since | ||
| 36 | } | ||
| 37 | if f.Until != nil { | ||
| 38 | m["until"] = *f.Until | ||
| 39 | } | ||
| 40 | if f.Limit > 0 { | ||
| 41 | m["limit"] = f.Limit | ||
| 42 | } | ||
| 43 | |||
| 44 | // Add tag filters with # prefix | ||
| 45 | for key, values := range f.Tags { | ||
| 46 | if len(values) > 0 { | ||
| 47 | m["#"+key] = values | ||
| 48 | } | ||
| 49 | } | ||
| 50 | |||
| 51 | return json.Marshal(m) | ||
| 52 | } | ||
| 53 | |||
| 54 | // UnmarshalJSON implements json.Unmarshaler for Filter. | ||
| 55 | // Extracts #e, #p format into Tags map. | ||
| 56 | func (f *Filter) UnmarshalJSON(data []byte) error { | ||
| 57 | // First unmarshal into a raw map | ||
| 58 | var raw map[string]json.RawMessage | ||
| 59 | if err := json.Unmarshal(data, &raw); err != nil { | ||
| 60 | return err | ||
| 61 | } | ||
| 62 | |||
| 63 | // Extract known fields | ||
| 64 | if v, ok := raw["ids"]; ok { | ||
| 65 | if err := json.Unmarshal(v, &f.IDs); err != nil { | ||
| 66 | return err | ||
| 67 | } | ||
| 68 | } | ||
| 69 | if v, ok := raw["kinds"]; ok { | ||
| 70 | if err := json.Unmarshal(v, &f.Kinds); err != nil { | ||
| 71 | return err | ||
| 72 | } | ||
| 73 | } | ||
| 74 | if v, ok := raw["authors"]; ok { | ||
| 75 | if err := json.Unmarshal(v, &f.Authors); err != nil { | ||
| 76 | return err | ||
| 77 | } | ||
| 78 | } | ||
| 79 | if v, ok := raw["since"]; ok { | ||
| 80 | var since int64 | ||
| 81 | if err := json.Unmarshal(v, &since); err != nil { | ||
| 82 | return err | ||
| 83 | } | ||
| 84 | f.Since = &since | ||
| 85 | } | ||
| 86 | if v, ok := raw["until"]; ok { | ||
| 87 | var until int64 | ||
| 88 | if err := json.Unmarshal(v, &until); err != nil { | ||
| 89 | return err | ||
| 90 | } | ||
| 91 | f.Until = &until | ||
| 92 | } | ||
| 93 | if v, ok := raw["limit"]; ok { | ||
| 94 | if err := json.Unmarshal(v, &f.Limit); err != nil { | ||
| 95 | return err | ||
| 96 | } | ||
| 97 | } | ||
| 98 | |||
| 99 | // Extract tag filters (fields starting with #) | ||
| 100 | f.Tags = make(map[string][]string) | ||
| 101 | for key, value := range raw { | ||
| 102 | if strings.HasPrefix(key, "#") { | ||
| 103 | tagKey := strings.TrimPrefix(key, "#") | ||
| 104 | var values []string | ||
| 105 | if err := json.Unmarshal(value, &values); err != nil { | ||
| 106 | return err | ||
| 107 | } | ||
| 108 | f.Tags[tagKey] = values | ||
| 109 | } | ||
| 110 | } | ||
| 111 | |||
| 112 | return nil | ||
| 113 | } | ||
| 114 | |||
| 115 | // Matches checks if an event matches this filter. | ||
| 116 | func (f *Filter) Matches(event *Event) bool { | ||
| 117 | // Check IDs (prefix match) | ||
| 118 | if len(f.IDs) > 0 { | ||
| 119 | found := false | ||
| 120 | for _, id := range f.IDs { | ||
| 121 | if strings.HasPrefix(event.ID, id) { | ||
| 122 | found = true | ||
| 123 | break | ||
| 124 | } | ||
| 125 | } | ||
| 126 | if !found { | ||
| 127 | return false | ||
| 128 | } | ||
| 129 | } | ||
| 130 | |||
| 131 | // Check authors (prefix match) | ||
| 132 | if len(f.Authors) > 0 { | ||
| 133 | found := false | ||
| 134 | for _, author := range f.Authors { | ||
| 135 | if strings.HasPrefix(event.PubKey, author) { | ||
| 136 | found = true | ||
| 137 | break | ||
| 138 | } | ||
| 139 | } | ||
| 140 | if !found { | ||
| 141 | return false | ||
| 142 | } | ||
| 143 | } | ||
| 144 | |||
| 145 | // Check kinds | ||
| 146 | if len(f.Kinds) > 0 { | ||
| 147 | found := false | ||
| 148 | for _, kind := range f.Kinds { | ||
| 149 | if event.Kind == kind { | ||
| 150 | found = true | ||
| 151 | break | ||
| 152 | } | ||
| 153 | } | ||
| 154 | if !found { | ||
| 155 | return false | ||
| 156 | } | ||
| 157 | } | ||
| 158 | |||
| 159 | // Check since | ||
| 160 | if f.Since != nil && event.CreatedAt < *f.Since { | ||
| 161 | return false | ||
| 162 | } | ||
| 163 | |||
| 164 | // Check until | ||
| 165 | if f.Until != nil && event.CreatedAt > *f.Until { | ||
| 166 | return false | ||
| 167 | } | ||
| 168 | |||
| 169 | // Check tag filters | ||
| 170 | for tagKey, values := range f.Tags { | ||
| 171 | if len(values) == 0 { | ||
| 172 | continue | ||
| 173 | } | ||
| 174 | found := false | ||
| 175 | for _, val := range values { | ||
| 176 | if event.Tags.ContainsValue(tagKey, val) { | ||
| 177 | found = true | ||
| 178 | break | ||
| 179 | } | ||
| 180 | } | ||
| 181 | if !found { | ||
| 182 | return false | ||
| 183 | } | ||
| 184 | } | ||
| 185 | |||
| 186 | return true | ||
| 187 | } | ||
| 188 | |||
| 189 | // Clone creates a deep copy of the filter. | ||
| 190 | func (f *Filter) Clone() *Filter { | ||
| 191 | clone := &Filter{ | ||
| 192 | Limit: f.Limit, | ||
| 193 | } | ||
| 194 | |||
| 195 | if f.IDs != nil { | ||
| 196 | clone.IDs = make([]string, len(f.IDs)) | ||
| 197 | copy(clone.IDs, f.IDs) | ||
| 198 | } | ||
| 199 | if f.Kinds != nil { | ||
| 200 | clone.Kinds = make([]int, len(f.Kinds)) | ||
| 201 | copy(clone.Kinds, f.Kinds) | ||
| 202 | } | ||
| 203 | if f.Authors != nil { | ||
| 204 | clone.Authors = make([]string, len(f.Authors)) | ||
| 205 | copy(clone.Authors, f.Authors) | ||
| 206 | } | ||
| 207 | if f.Since != nil { | ||
| 208 | since := *f.Since | ||
| 209 | clone.Since = &since | ||
| 210 | } | ||
| 211 | if f.Until != nil { | ||
| 212 | until := *f.Until | ||
| 213 | clone.Until = &until | ||
| 214 | } | ||
| 215 | if f.Tags != nil { | ||
| 216 | clone.Tags = make(map[string][]string) | ||
| 217 | for k, v := range f.Tags { | ||
| 218 | clone.Tags[k] = make([]string, len(v)) | ||
| 219 | copy(clone.Tags[k], v) | ||
| 220 | } | ||
| 221 | } | ||
| 222 | |||
| 223 | return clone | ||
| 224 | } | ||
diff --git a/internal/nostr/filter_test.go b/internal/nostr/filter_test.go new file mode 100644 index 0000000..ebe2b1d --- /dev/null +++ b/internal/nostr/filter_test.go | |||
| @@ -0,0 +1,415 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "encoding/json" | ||
| 5 | "testing" | ||
| 6 | ) | ||
| 7 | |||
| 8 | func TestFilterMarshalJSON(t *testing.T) { | ||
| 9 | since := int64(1704067200) | ||
| 10 | until := int64(1704153600) | ||
| 11 | |||
| 12 | filter := Filter{ | ||
| 13 | IDs: []string{"abc123"}, | ||
| 14 | Kinds: []int{1, 7}, | ||
| 15 | Authors: []string{"def456"}, | ||
| 16 | Tags: map[string][]string{ | ||
| 17 | "e": {"event1", "event2"}, | ||
| 18 | "p": {"pubkey1"}, | ||
| 19 | }, | ||
| 20 | Since: &since, | ||
| 21 | Until: &until, | ||
| 22 | Limit: 100, | ||
| 23 | } | ||
| 24 | |||
| 25 | data, err := filter.MarshalJSON() | ||
| 26 | if err != nil { | ||
| 27 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 28 | } | ||
| 29 | |||
| 30 | // Parse and check structure | ||
| 31 | var m map[string]interface{} | ||
| 32 | if err := json.Unmarshal(data, &m); err != nil { | ||
| 33 | t.Fatalf("Failed to unmarshal: %v", err) | ||
| 34 | } | ||
| 35 | |||
| 36 | // Check regular fields | ||
| 37 | if _, ok := m["ids"]; !ok { | ||
| 38 | t.Error("ids field missing") | ||
| 39 | } | ||
| 40 | if _, ok := m["kinds"]; !ok { | ||
| 41 | t.Error("kinds field missing") | ||
| 42 | } | ||
| 43 | if _, ok := m["authors"]; !ok { | ||
| 44 | t.Error("authors field missing") | ||
| 45 | } | ||
| 46 | if _, ok := m["since"]; !ok { | ||
| 47 | t.Error("since field missing") | ||
| 48 | } | ||
| 49 | if _, ok := m["until"]; !ok { | ||
| 50 | t.Error("until field missing") | ||
| 51 | } | ||
| 52 | if _, ok := m["limit"]; !ok { | ||
| 53 | t.Error("limit field missing") | ||
| 54 | } | ||
| 55 | |||
| 56 | // Check tag filters with # prefix | ||
| 57 | if _, ok := m["#e"]; !ok { | ||
| 58 | t.Error("#e field missing") | ||
| 59 | } | ||
| 60 | if _, ok := m["#p"]; !ok { | ||
| 61 | t.Error("#p field missing") | ||
| 62 | } | ||
| 63 | } | ||
| 64 | |||
| 65 | func TestFilterMarshalJSONOmitsEmpty(t *testing.T) { | ||
| 66 | filter := Filter{ | ||
| 67 | Kinds: []int{1}, | ||
| 68 | } | ||
| 69 | |||
| 70 | data, err := filter.MarshalJSON() | ||
| 71 | if err != nil { | ||
| 72 | t.Fatalf("MarshalJSON() error = %v", err) | ||
| 73 | } | ||
| 74 | |||
| 75 | var m map[string]interface{} | ||
| 76 | if err := json.Unmarshal(data, &m); err != nil { | ||
| 77 | t.Fatalf("Failed to unmarshal: %v", err) | ||
| 78 | } | ||
| 79 | |||
| 80 | if _, ok := m["ids"]; ok { | ||
| 81 | t.Error("empty ids should be omitted") | ||
| 82 | } | ||
| 83 | if _, ok := m["authors"]; ok { | ||
| 84 | t.Error("empty authors should be omitted") | ||
| 85 | } | ||
| 86 | if _, ok := m["since"]; ok { | ||
| 87 | t.Error("nil since should be omitted") | ||
| 88 | } | ||
| 89 | if _, ok := m["until"]; ok { | ||
| 90 | t.Error("nil until should be omitted") | ||
| 91 | } | ||
| 92 | if _, ok := m["limit"]; ok { | ||
| 93 | t.Error("zero limit should be omitted") | ||
| 94 | } | ||
| 95 | } | ||
| 96 | |||
| 97 | func TestFilterUnmarshalJSON(t *testing.T) { | ||
| 98 | jsonData := `{ | ||
| 99 | "ids": ["abc123"], | ||
| 100 | "kinds": [1, 7], | ||
| 101 | "authors": ["def456"], | ||
| 102 | "#e": ["event1", "event2"], | ||
| 103 | "#p": ["pubkey1"], | ||
| 104 | "since": 1704067200, | ||
| 105 | "until": 1704153600, | ||
| 106 | "limit": 100 | ||
| 107 | }` | ||
| 108 | |||
| 109 | var filter Filter | ||
| 110 | if err := json.Unmarshal([]byte(jsonData), &filter); err != nil { | ||
| 111 | t.Fatalf("UnmarshalJSON() error = %v", err) | ||
| 112 | } | ||
| 113 | |||
| 114 | if len(filter.IDs) != 1 || filter.IDs[0] != "abc123" { | ||
| 115 | t.Errorf("IDs = %v, want [abc123]", filter.IDs) | ||
| 116 | } | ||
| 117 | if len(filter.Kinds) != 2 { | ||
| 118 | t.Errorf("Kinds length = %d, want 2", len(filter.Kinds)) | ||
| 119 | } | ||
| 120 | if len(filter.Authors) != 1 || filter.Authors[0] != "def456" { | ||
| 121 | t.Errorf("Authors = %v, want [def456]", filter.Authors) | ||
| 122 | } | ||
| 123 | if filter.Since == nil || *filter.Since != 1704067200 { | ||
| 124 | t.Errorf("Since = %v, want 1704067200", filter.Since) | ||
| 125 | } | ||
| 126 | if filter.Until == nil || *filter.Until != 1704153600 { | ||
| 127 | t.Errorf("Until = %v, want 1704153600", filter.Until) | ||
| 128 | } | ||
| 129 | if filter.Limit != 100 { | ||
| 130 | t.Errorf("Limit = %d, want 100", filter.Limit) | ||
| 131 | } | ||
| 132 | |||
| 133 | // Check tag filters | ||
| 134 | if len(filter.Tags["e"]) != 2 { | ||
| 135 | t.Errorf("Tags[e] length = %d, want 2", len(filter.Tags["e"])) | ||
| 136 | } | ||
| 137 | if len(filter.Tags["p"]) != 1 { | ||
| 138 | t.Errorf("Tags[p] length = %d, want 1", len(filter.Tags["p"])) | ||
| 139 | } | ||
| 140 | } | ||
| 141 | |||
| 142 | func TestFilterMatchesIDs(t *testing.T) { | ||
| 143 | filter := Filter{ | ||
| 144 | IDs: []string{"abc", "def456"}, | ||
| 145 | } | ||
| 146 | |||
| 147 | tests := []struct { | ||
| 148 | id string | ||
| 149 | want bool | ||
| 150 | }{ | ||
| 151 | {"abc123", true}, // matches prefix "abc" | ||
| 152 | {"abcdef", true}, // matches prefix "abc" | ||
| 153 | {"def456", true}, // exact match | ||
| 154 | {"def456xyz", true}, // matches prefix "def456" | ||
| 155 | {"xyz789", false}, // no match | ||
| 156 | {"ab", false}, // "ab" doesn't start with "abc" | ||
| 157 | } | ||
| 158 | |||
| 159 | for _, tt := range tests { | ||
| 160 | event := &Event{ID: tt.id} | ||
| 161 | if got := filter.Matches(event); got != tt.want { | ||
| 162 | t.Errorf("Matches() with ID %s = %v, want %v", tt.id, got, tt.want) | ||
| 163 | } | ||
| 164 | } | ||
| 165 | } | ||
| 166 | |||
| 167 | func TestFilterMatchesAuthors(t *testing.T) { | ||
| 168 | filter := Filter{ | ||
| 169 | Authors: []string{"pubkey1", "pubkey2"}, | ||
| 170 | } | ||
| 171 | |||
| 172 | tests := []struct { | ||
| 173 | pubkey string | ||
| 174 | want bool | ||
| 175 | }{ | ||
| 176 | {"pubkey1", true}, | ||
| 177 | {"pubkey1abc", true}, // Prefix match | ||
| 178 | {"pubkey2", true}, | ||
| 179 | {"pubkey3", false}, | ||
| 180 | } | ||
| 181 | |||
| 182 | for _, tt := range tests { | ||
| 183 | event := &Event{PubKey: tt.pubkey} | ||
| 184 | if got := filter.Matches(event); got != tt.want { | ||
| 185 | t.Errorf("Matches() with PubKey %s = %v, want %v", tt.pubkey, got, tt.want) | ||
| 186 | } | ||
| 187 | } | ||
| 188 | } | ||
| 189 | |||
| 190 | func TestFilterMatchesKinds(t *testing.T) { | ||
| 191 | filter := Filter{ | ||
| 192 | Kinds: []int{1, 7}, | ||
| 193 | } | ||
| 194 | |||
| 195 | tests := []struct { | ||
| 196 | kind int | ||
| 197 | want bool | ||
| 198 | }{ | ||
| 199 | {1, true}, | ||
| 200 | {7, true}, | ||
| 201 | {0, false}, | ||
| 202 | {4, false}, | ||
| 203 | } | ||
| 204 | |||
| 205 | for _, tt := range tests { | ||
| 206 | event := &Event{Kind: tt.kind} | ||
| 207 | if got := filter.Matches(event); got != tt.want { | ||
| 208 | t.Errorf("Matches() with Kind %d = %v, want %v", tt.kind, got, tt.want) | ||
| 209 | } | ||
| 210 | } | ||
| 211 | } | ||
| 212 | |||
| 213 | func TestFilterMatchesSince(t *testing.T) { | ||
| 214 | since := int64(1704067200) | ||
| 215 | filter := Filter{ | ||
| 216 | Since: &since, | ||
| 217 | } | ||
| 218 | |||
| 219 | tests := []struct { | ||
| 220 | createdAt int64 | ||
| 221 | want bool | ||
| 222 | }{ | ||
| 223 | {1704067200, true}, // Equal | ||
| 224 | {1704067201, true}, // After | ||
| 225 | {1704067199, false}, // Before | ||
| 226 | } | ||
| 227 | |||
| 228 | for _, tt := range tests { | ||
| 229 | event := &Event{CreatedAt: tt.createdAt} | ||
| 230 | if got := filter.Matches(event); got != tt.want { | ||
| 231 | t.Errorf("Matches() with CreatedAt %d = %v, want %v", tt.createdAt, got, tt.want) | ||
| 232 | } | ||
| 233 | } | ||
| 234 | } | ||
| 235 | |||
| 236 | func TestFilterMatchesUntil(t *testing.T) { | ||
| 237 | until := int64(1704067200) | ||
| 238 | filter := Filter{ | ||
| 239 | Until: &until, | ||
| 240 | } | ||
| 241 | |||
| 242 | tests := []struct { | ||
| 243 | createdAt int64 | ||
| 244 | want bool | ||
| 245 | }{ | ||
| 246 | {1704067200, true}, // Equal | ||
| 247 | {1704067199, true}, // Before | ||
| 248 | {1704067201, false}, // After | ||
| 249 | } | ||
| 250 | |||
| 251 | for _, tt := range tests { | ||
| 252 | event := &Event{CreatedAt: tt.createdAt} | ||
| 253 | if got := filter.Matches(event); got != tt.want { | ||
| 254 | t.Errorf("Matches() with CreatedAt %d = %v, want %v", tt.createdAt, got, tt.want) | ||
| 255 | } | ||
| 256 | } | ||
| 257 | } | ||
| 258 | |||
| 259 | func TestFilterMatchesTags(t *testing.T) { | ||
| 260 | filter := Filter{ | ||
| 261 | Tags: map[string][]string{ | ||
| 262 | "e": {"event1"}, | ||
| 263 | "p": {"pubkey1", "pubkey2"}, | ||
| 264 | }, | ||
| 265 | } | ||
| 266 | |||
| 267 | tests := []struct { | ||
| 268 | name string | ||
| 269 | tags Tags | ||
| 270 | want bool | ||
| 271 | }{ | ||
| 272 | { | ||
| 273 | name: "matches all", | ||
| 274 | tags: Tags{{"e", "event1"}, {"p", "pubkey1"}}, | ||
| 275 | want: true, | ||
| 276 | }, | ||
| 277 | { | ||
| 278 | name: "matches with different p", | ||
| 279 | tags: Tags{{"e", "event1"}, {"p", "pubkey2"}}, | ||
| 280 | want: true, | ||
| 281 | }, | ||
| 282 | { | ||
| 283 | name: "missing e tag", | ||
| 284 | tags: Tags{{"p", "pubkey1"}}, | ||
| 285 | want: false, | ||
| 286 | }, | ||
| 287 | { | ||
| 288 | name: "wrong e value", | ||
| 289 | tags: Tags{{"e", "event2"}, {"p", "pubkey1"}}, | ||
| 290 | want: false, | ||
| 291 | }, | ||
| 292 | { | ||
| 293 | name: "extra tags ok", | ||
| 294 | tags: Tags{{"e", "event1"}, {"p", "pubkey1"}, {"t", "test"}}, | ||
| 295 | want: true, | ||
| 296 | }, | ||
| 297 | } | ||
| 298 | |||
| 299 | for _, tt := range tests { | ||
| 300 | t.Run(tt.name, func(t *testing.T) { | ||
| 301 | event := &Event{Tags: tt.tags} | ||
| 302 | if got := filter.Matches(event); got != tt.want { | ||
| 303 | t.Errorf("Matches() = %v, want %v", got, tt.want) | ||
| 304 | } | ||
| 305 | }) | ||
| 306 | } | ||
| 307 | } | ||
| 308 | |||
| 309 | func TestFilterMatchesEmpty(t *testing.T) { | ||
| 310 | // Empty filter matches everything | ||
| 311 | filter := Filter{} | ||
| 312 | event := &Event{ | ||
| 313 | ID: "abc123", | ||
| 314 | PubKey: "pubkey1", | ||
| 315 | CreatedAt: 1704067200, | ||
| 316 | Kind: 1, | ||
| 317 | Tags: Tags{{"e", "event1"}}, | ||
| 318 | Content: "test", | ||
| 319 | } | ||
| 320 | |||
| 321 | if !filter.Matches(event) { | ||
| 322 | t.Error("Empty filter should match all events") | ||
| 323 | } | ||
| 324 | } | ||
| 325 | |||
| 326 | func TestFilterClone(t *testing.T) { | ||
| 327 | since := int64(1704067200) | ||
| 328 | until := int64(1704153600) | ||
| 329 | |||
| 330 | original := &Filter{ | ||
| 331 | IDs: []string{"id1", "id2"}, | ||
| 332 | Kinds: []int{1, 7}, | ||
| 333 | Authors: []string{"author1"}, | ||
| 334 | Tags: map[string][]string{ | ||
| 335 | "e": {"event1"}, | ||
| 336 | }, | ||
| 337 | Since: &since, | ||
| 338 | Until: &until, | ||
| 339 | Limit: 100, | ||
| 340 | } | ||
| 341 | |||
| 342 | clone := original.Clone() | ||
| 343 | |||
| 344 | // Modify original | ||
| 345 | original.IDs[0] = "modified" | ||
| 346 | original.Kinds[0] = 999 | ||
| 347 | original.Authors[0] = "modified" | ||
| 348 | original.Tags["e"][0] = "modified" | ||
| 349 | *original.Since = 0 | ||
| 350 | *original.Until = 0 | ||
| 351 | original.Limit = 0 | ||
| 352 | |||
| 353 | // Clone should be unchanged | ||
| 354 | if clone.IDs[0] != "id1" { | ||
| 355 | t.Error("Clone IDs was modified") | ||
| 356 | } | ||
| 357 | if clone.Kinds[0] != 1 { | ||
| 358 | t.Error("Clone Kinds was modified") | ||
| 359 | } | ||
| 360 | if clone.Authors[0] != "author1" { | ||
| 361 | t.Error("Clone Authors was modified") | ||
| 362 | } | ||
| 363 | if clone.Tags["e"][0] != "event1" { | ||
| 364 | t.Error("Clone Tags was modified") | ||
| 365 | } | ||
| 366 | if *clone.Since != 1704067200 { | ||
| 367 | t.Error("Clone Since was modified") | ||
| 368 | } | ||
| 369 | if *clone.Until != 1704153600 { | ||
| 370 | t.Error("Clone Until was modified") | ||
| 371 | } | ||
| 372 | if clone.Limit != 100 { | ||
| 373 | t.Error("Clone Limit was modified") | ||
| 374 | } | ||
| 375 | } | ||
| 376 | |||
| 377 | func TestFilterJSONRoundTrip(t *testing.T) { | ||
| 378 | since := int64(1704067200) | ||
| 379 | original := Filter{ | ||
| 380 | IDs: []string{"abc123"}, | ||
| 381 | Kinds: []int{1}, | ||
| 382 | Authors: []string{"def456"}, | ||
| 383 | Tags: map[string][]string{ | ||
| 384 | "e": {"event1"}, | ||
| 385 | }, | ||
| 386 | Since: &since, | ||
| 387 | Limit: 50, | ||
| 388 | } | ||
| 389 | |||
| 390 | data, err := json.Marshal(original) | ||
| 391 | if err != nil { | ||
| 392 | t.Fatalf("Marshal error: %v", err) | ||
| 393 | } | ||
| 394 | |||
| 395 | var decoded Filter | ||
| 396 | if err := json.Unmarshal(data, &decoded); err != nil { | ||
| 397 | t.Fatalf("Unmarshal error: %v", err) | ||
| 398 | } | ||
| 399 | |||
| 400 | if len(decoded.IDs) != 1 || decoded.IDs[0] != "abc123" { | ||
| 401 | t.Errorf("IDs mismatch") | ||
| 402 | } | ||
| 403 | if len(decoded.Kinds) != 1 || decoded.Kinds[0] != 1 { | ||
| 404 | t.Errorf("Kinds mismatch") | ||
| 405 | } | ||
| 406 | if len(decoded.Tags["e"]) != 1 || decoded.Tags["e"][0] != "event1" { | ||
| 407 | t.Errorf("Tags mismatch") | ||
| 408 | } | ||
| 409 | if decoded.Since == nil || *decoded.Since != since { | ||
| 410 | t.Errorf("Since mismatch") | ||
| 411 | } | ||
| 412 | if decoded.Limit != 50 { | ||
| 413 | t.Errorf("Limit mismatch") | ||
| 414 | } | ||
| 415 | } | ||
diff --git a/internal/nostr/keys.go b/internal/nostr/keys.go new file mode 100644 index 0000000..3a3fb9c --- /dev/null +++ b/internal/nostr/keys.go | |||
| @@ -0,0 +1,217 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "crypto/rand" | ||
| 5 | "encoding/hex" | ||
| 6 | "fmt" | ||
| 7 | "strings" | ||
| 8 | "time" | ||
| 9 | |||
| 10 | "github.com/btcsuite/btcd/btcec/v2" | ||
| 11 | "github.com/btcsuite/btcd/btcec/v2/schnorr" | ||
| 12 | ) | ||
| 13 | |||
| 14 | // Key represents a Nostr key, which may be a full private key or public-only. | ||
| 15 | // Use GenerateKey or ParseKey for private keys, ParsePublicKey for public-only. | ||
| 16 | type Key struct { | ||
| 17 | priv *btcec.PrivateKey // nil for public-only keys | ||
| 18 | pub *btcec.PublicKey // always set | ||
| 19 | } | ||
| 20 | |||
| 21 | // GenerateKey generates a new random private key. | ||
| 22 | func GenerateKey() (*Key, error) { | ||
| 23 | var keyBytes [32]byte | ||
| 24 | if _, err := rand.Read(keyBytes[:]); err != nil { | ||
| 25 | return nil, fmt.Errorf("failed to generate random bytes: %w", err) | ||
| 26 | } | ||
| 27 | |||
| 28 | priv, _ := btcec.PrivKeyFromBytes(keyBytes[:]) | ||
| 29 | return &Key{ | ||
| 30 | priv: priv, | ||
| 31 | pub: priv.PubKey(), | ||
| 32 | }, nil | ||
| 33 | } | ||
| 34 | |||
| 35 | // ParseKey parses a private key from hex or nsec (bech32) format. | ||
| 36 | func ParseKey(s string) (*Key, error) { | ||
| 37 | var privBytes []byte | ||
| 38 | |||
| 39 | if strings.HasPrefix(s, "nsec1") { | ||
| 40 | hrp, data, err := Bech32Decode(s) | ||
| 41 | if err != nil { | ||
| 42 | return nil, fmt.Errorf("invalid nsec: %w", err) | ||
| 43 | } | ||
| 44 | if hrp != "nsec" { | ||
| 45 | return nil, fmt.Errorf("invalid prefix: expected nsec, got %s", hrp) | ||
| 46 | } | ||
| 47 | if len(data) != 32 { | ||
| 48 | return nil, fmt.Errorf("invalid nsec data length: %d", len(data)) | ||
| 49 | } | ||
| 50 | privBytes = data | ||
| 51 | } else { | ||
| 52 | var err error | ||
| 53 | privBytes, err = hex.DecodeString(s) | ||
| 54 | if err != nil { | ||
| 55 | return nil, fmt.Errorf("invalid hex: %w", err) | ||
| 56 | } | ||
| 57 | } | ||
| 58 | |||
| 59 | if len(privBytes) != 32 { | ||
| 60 | return nil, fmt.Errorf("private key must be 32 bytes, got %d", len(privBytes)) | ||
| 61 | } | ||
| 62 | |||
| 63 | priv, _ := btcec.PrivKeyFromBytes(privBytes) | ||
| 64 | return &Key{ | ||
| 65 | priv: priv, | ||
| 66 | pub: priv.PubKey(), | ||
| 67 | }, nil | ||
| 68 | } | ||
| 69 | |||
| 70 | // ParsePublicKey parses a public key from hex or npub (bech32) format. | ||
| 71 | // The returned Key can only verify, not sign. | ||
| 72 | func ParsePublicKey(s string) (*Key, error) { | ||
| 73 | var pubBytes []byte | ||
| 74 | |||
| 75 | if strings.HasPrefix(s, "npub1") { | ||
| 76 | hrp, data, err := Bech32Decode(s) | ||
| 77 | if err != nil { | ||
| 78 | return nil, fmt.Errorf("invalid npub: %w", err) | ||
| 79 | } | ||
| 80 | if hrp != "npub" { | ||
| 81 | return nil, fmt.Errorf("invalid prefix: expected npub, got %s", hrp) | ||
| 82 | } | ||
| 83 | if len(data) != 32 { | ||
| 84 | return nil, fmt.Errorf("invalid npub data length: %d", len(data)) | ||
| 85 | } | ||
| 86 | pubBytes = data | ||
| 87 | } else { | ||
| 88 | var err error | ||
| 89 | pubBytes, err = hex.DecodeString(s) | ||
| 90 | if err != nil { | ||
| 91 | return nil, fmt.Errorf("invalid hex: %w", err) | ||
| 92 | } | ||
| 93 | } | ||
| 94 | |||
| 95 | if len(pubBytes) != 32 { | ||
| 96 | return nil, fmt.Errorf("public key must be 32 bytes, got %d", len(pubBytes)) | ||
| 97 | } | ||
| 98 | |||
| 99 | pub, err := schnorr.ParsePubKey(pubBytes) | ||
| 100 | if err != nil { | ||
| 101 | return nil, fmt.Errorf("invalid public key: %w", err) | ||
| 102 | } | ||
| 103 | |||
| 104 | return &Key{ | ||
| 105 | priv: nil, | ||
| 106 | pub: pub, | ||
| 107 | }, nil | ||
| 108 | } | ||
| 109 | |||
| 110 | // CanSign returns true if this key can sign events (has private key). | ||
| 111 | func (k *Key) CanSign() bool { | ||
| 112 | return k.priv != nil | ||
| 113 | } | ||
| 114 | |||
| 115 | // Public returns the public key as a 64-character hex string. | ||
| 116 | func (k *Key) Public() string { | ||
| 117 | return hex.EncodeToString(schnorr.SerializePubKey(k.pub)) | ||
| 118 | } | ||
| 119 | |||
| 120 | // Private returns the private key as a 64-character hex string. | ||
| 121 | // Returns empty string if this is a public-only key. | ||
| 122 | func (k *Key) Private() string { | ||
| 123 | if k.priv == nil { | ||
| 124 | return "" | ||
| 125 | } | ||
| 126 | return hex.EncodeToString(k.priv.Serialize()) | ||
| 127 | } | ||
| 128 | |||
| 129 | // Npub returns the public key in bech32 npub format. | ||
| 130 | func (k *Key) Npub() string { | ||
| 131 | pubBytes := schnorr.SerializePubKey(k.pub) | ||
| 132 | npub, _ := Bech32Encode("npub", pubBytes) | ||
| 133 | return npub | ||
| 134 | } | ||
| 135 | |||
| 136 | // Nsec returns the private key in bech32 nsec format. | ||
| 137 | // Returns empty string if this is a public-only key. | ||
| 138 | func (k *Key) Nsec() string { | ||
| 139 | if k.priv == nil { | ||
| 140 | return "" | ||
| 141 | } | ||
| 142 | nsec, _ := Bech32Encode("nsec", k.priv.Serialize()) | ||
| 143 | return nsec | ||
| 144 | } | ||
| 145 | |||
| 146 | // Sign signs the event with this key. | ||
| 147 | // Sets the PubKey, ID, and Sig fields on the event. | ||
| 148 | // Returns an error if this is a public-only key. | ||
| 149 | func (k *Key) Sign(event *Event) error { | ||
| 150 | if k.priv == nil { | ||
| 151 | return fmt.Errorf("cannot sign: public-only key") | ||
| 152 | } | ||
| 153 | |||
| 154 | // Set public key | ||
| 155 | event.PubKey = k.Public() | ||
| 156 | |||
| 157 | if event.CreatedAt == 0 { | ||
| 158 | event.CreatedAt = time.Now().Unix() | ||
| 159 | } | ||
| 160 | |||
| 161 | // Compute ID | ||
| 162 | event.SetID() | ||
| 163 | |||
| 164 | // Hash the ID for signing | ||
| 165 | idBytes, err := hex.DecodeString(event.ID) | ||
| 166 | if err != nil { | ||
| 167 | return fmt.Errorf("failed to decode event ID: %w", err) | ||
| 168 | } | ||
| 169 | |||
| 170 | // Sign with Schnorr | ||
| 171 | sig, err := schnorr.Sign(k.priv, idBytes) | ||
| 172 | if err != nil { | ||
| 173 | return fmt.Errorf("failed to sign event: %w", err) | ||
| 174 | } | ||
| 175 | |||
| 176 | event.Sig = hex.EncodeToString(sig.Serialize()) | ||
| 177 | return nil | ||
| 178 | } | ||
| 179 | |||
| 180 | // Verify verifies the event signature. | ||
| 181 | // Returns true if the signature is valid, false otherwise. | ||
| 182 | func (e *Event) Verify() bool { | ||
| 183 | // Verify ID first | ||
| 184 | if !e.CheckID() { | ||
| 185 | return false | ||
| 186 | } | ||
| 187 | |||
| 188 | // Decode public key | ||
| 189 | pubKeyBytes, err := hex.DecodeString(e.PubKey) | ||
| 190 | if err != nil || len(pubKeyBytes) != 32 { | ||
| 191 | return false | ||
| 192 | } | ||
| 193 | |||
| 194 | pubKey, err := schnorr.ParsePubKey(pubKeyBytes) | ||
| 195 | if err != nil { | ||
| 196 | return false | ||
| 197 | } | ||
| 198 | |||
| 199 | // Decode signature | ||
| 200 | sigBytes, err := hex.DecodeString(e.Sig) | ||
| 201 | if err != nil { | ||
| 202 | return false | ||
| 203 | } | ||
| 204 | |||
| 205 | sig, err := schnorr.ParseSignature(sigBytes) | ||
| 206 | if err != nil { | ||
| 207 | return false | ||
| 208 | } | ||
| 209 | |||
| 210 | // Decode ID (message hash) | ||
| 211 | idBytes, err := hex.DecodeString(e.ID) | ||
| 212 | if err != nil { | ||
| 213 | return false | ||
| 214 | } | ||
| 215 | |||
| 216 | return sig.Verify(idBytes, pubKey) | ||
| 217 | } | ||
diff --git a/internal/nostr/keys_test.go b/internal/nostr/keys_test.go new file mode 100644 index 0000000..6c3dd3d --- /dev/null +++ b/internal/nostr/keys_test.go | |||
| @@ -0,0 +1,333 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "encoding/hex" | ||
| 5 | "strings" | ||
| 6 | "testing" | ||
| 7 | ) | ||
| 8 | |||
| 9 | func TestGenerateKey(t *testing.T) { | ||
| 10 | key1, err := GenerateKey() | ||
| 11 | if err != nil { | ||
| 12 | t.Fatalf("GenerateKey() error = %v", err) | ||
| 13 | } | ||
| 14 | |||
| 15 | if !key1.CanSign() { | ||
| 16 | t.Error("Generated key should be able to sign") | ||
| 17 | } | ||
| 18 | |||
| 19 | // Private key should be 64 hex characters | ||
| 20 | if len(key1.Private()) != 64 { | ||
| 21 | t.Errorf("Private() length = %d, want 64", len(key1.Private())) | ||
| 22 | } | ||
| 23 | |||
| 24 | // Public key should be 64 hex characters | ||
| 25 | if len(key1.Public()) != 64 { | ||
| 26 | t.Errorf("Public() length = %d, want 64", len(key1.Public())) | ||
| 27 | } | ||
| 28 | |||
| 29 | // Should be valid hex | ||
| 30 | if _, err := hex.DecodeString(key1.Private()); err != nil { | ||
| 31 | t.Errorf("Private() is not valid hex: %v", err) | ||
| 32 | } | ||
| 33 | if _, err := hex.DecodeString(key1.Public()); err != nil { | ||
| 34 | t.Errorf("Public() is not valid hex: %v", err) | ||
| 35 | } | ||
| 36 | |||
| 37 | // Keys should be unique | ||
| 38 | key2, err := GenerateKey() | ||
| 39 | if err != nil { | ||
| 40 | t.Fatalf("GenerateKey() second call error = %v", err) | ||
| 41 | } | ||
| 42 | if key1.Private() == key2.Private() { | ||
| 43 | t.Error("GenerateKey() returned same private key twice") | ||
| 44 | } | ||
| 45 | } | ||
| 46 | |||
| 47 | func TestKeyNpubNsec(t *testing.T) { | ||
| 48 | key, err := GenerateKey() | ||
| 49 | if err != nil { | ||
| 50 | t.Fatalf("GenerateKey() error = %v", err) | ||
| 51 | } | ||
| 52 | |||
| 53 | npub := key.Npub() | ||
| 54 | nsec := key.Nsec() | ||
| 55 | |||
| 56 | // Check prefixes | ||
| 57 | if !strings.HasPrefix(npub, "npub1") { | ||
| 58 | t.Errorf("Npub() = %s, want prefix 'npub1'", npub) | ||
| 59 | } | ||
| 60 | if !strings.HasPrefix(nsec, "nsec1") { | ||
| 61 | t.Errorf("Nsec() = %s, want prefix 'nsec1'", nsec) | ||
| 62 | } | ||
| 63 | |||
| 64 | // Should be able to parse them back | ||
| 65 | keyFromNsec, err := ParseKey(nsec) | ||
| 66 | if err != nil { | ||
| 67 | t.Fatalf("ParseKey(nsec) error = %v", err) | ||
| 68 | } | ||
| 69 | if keyFromNsec.Private() != key.Private() { | ||
| 70 | t.Error("ParseKey(nsec) did not restore original private key") | ||
| 71 | } | ||
| 72 | |||
| 73 | keyFromNpub, err := ParsePublicKey(npub) | ||
| 74 | if err != nil { | ||
| 75 | t.Fatalf("ParsePublicKey(npub) error = %v", err) | ||
| 76 | } | ||
| 77 | if keyFromNpub.Public() != key.Public() { | ||
| 78 | t.Error("ParsePublicKey(npub) did not restore original public key") | ||
| 79 | } | ||
| 80 | } | ||
| 81 | |||
| 82 | func TestParseKey(t *testing.T) { | ||
| 83 | // Known test vector | ||
| 84 | hexKey := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" | ||
| 85 | |||
| 86 | key, err := ParseKey(hexKey) | ||
| 87 | if err != nil { | ||
| 88 | t.Fatalf("ParseKey(hex) error = %v", err) | ||
| 89 | } | ||
| 90 | |||
| 91 | if !key.CanSign() { | ||
| 92 | t.Error("ParseKey should return key that can sign") | ||
| 93 | } | ||
| 94 | |||
| 95 | if key.Private() != hexKey { | ||
| 96 | t.Errorf("Private() = %s, want %s", key.Private(), hexKey) | ||
| 97 | } | ||
| 98 | |||
| 99 | // Parse the nsec back | ||
| 100 | nsec := key.Nsec() | ||
| 101 | key2, err := ParseKey(nsec) | ||
| 102 | if err != nil { | ||
| 103 | t.Fatalf("ParseKey(nsec) error = %v", err) | ||
| 104 | } | ||
| 105 | if key2.Private() != hexKey { | ||
| 106 | t.Error("Round-trip through nsec failed") | ||
| 107 | } | ||
| 108 | } | ||
| 109 | |||
| 110 | func TestParseKeyErrors(t *testing.T) { | ||
| 111 | tests := []struct { | ||
| 112 | name string | ||
| 113 | key string | ||
| 114 | }{ | ||
| 115 | {"invalid hex", "not-hex"}, | ||
| 116 | {"too short", "0123456789abcdef"}, | ||
| 117 | {"too long", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef00"}, | ||
| 118 | {"invalid nsec", "nsec1invalid"}, | ||
| 119 | } | ||
| 120 | |||
| 121 | for _, tt := range tests { | ||
| 122 | t.Run(tt.name, func(t *testing.T) { | ||
| 123 | _, err := ParseKey(tt.key) | ||
| 124 | if err == nil { | ||
| 125 | t.Error("ParseKey() expected error, got nil") | ||
| 126 | } | ||
| 127 | }) | ||
| 128 | } | ||
| 129 | } | ||
| 130 | |||
| 131 | func TestParsePublicKey(t *testing.T) { | ||
| 132 | // Generate a key and extract public | ||
| 133 | fullKey, _ := GenerateKey() | ||
| 134 | pubHex := fullKey.Public() | ||
| 135 | |||
| 136 | // Parse public key from hex | ||
| 137 | key, err := ParsePublicKey(pubHex) | ||
| 138 | if err != nil { | ||
| 139 | t.Fatalf("ParsePublicKey(hex) error = %v", err) | ||
| 140 | } | ||
| 141 | |||
| 142 | if key.CanSign() { | ||
| 143 | t.Error("ParsePublicKey should return key that cannot sign") | ||
| 144 | } | ||
| 145 | |||
| 146 | if key.Public() != pubHex { | ||
| 147 | t.Errorf("Public() = %s, want %s", key.Public(), pubHex) | ||
| 148 | } | ||
| 149 | |||
| 150 | if key.Private() != "" { | ||
| 151 | t.Error("Private() should return empty string for public-only key") | ||
| 152 | } | ||
| 153 | |||
| 154 | if key.Nsec() != "" { | ||
| 155 | t.Error("Nsec() should return empty string for public-only key") | ||
| 156 | } | ||
| 157 | |||
| 158 | // Parse from npub | ||
| 159 | npub := fullKey.Npub() | ||
| 160 | key2, err := ParsePublicKey(npub) | ||
| 161 | if err != nil { | ||
| 162 | t.Fatalf("ParsePublicKey(npub) error = %v", err) | ||
| 163 | } | ||
| 164 | if key2.Public() != pubHex { | ||
| 165 | t.Error("ParsePublicKey(npub) did not restore correct public key") | ||
| 166 | } | ||
| 167 | } | ||
| 168 | |||
| 169 | func TestParsePublicKeyErrors(t *testing.T) { | ||
| 170 | tests := []struct { | ||
| 171 | name string | ||
| 172 | key string | ||
| 173 | }{ | ||
| 174 | {"invalid hex", "not-hex"}, | ||
| 175 | {"too short", "0123456789abcdef"}, | ||
| 176 | {"invalid npub", "npub1invalid"}, | ||
| 177 | } | ||
| 178 | |||
| 179 | for _, tt := range tests { | ||
| 180 | t.Run(tt.name, func(t *testing.T) { | ||
| 181 | _, err := ParsePublicKey(tt.key) | ||
| 182 | if err == nil { | ||
| 183 | t.Error("ParsePublicKey() expected error, got nil") | ||
| 184 | } | ||
| 185 | }) | ||
| 186 | } | ||
| 187 | } | ||
| 188 | |||
| 189 | func TestKeySign(t *testing.T) { | ||
| 190 | key, err := GenerateKey() | ||
| 191 | if err != nil { | ||
| 192 | t.Fatalf("GenerateKey() error = %v", err) | ||
| 193 | } | ||
| 194 | |||
| 195 | event := &Event{ | ||
| 196 | CreatedAt: 1704067200, | ||
| 197 | Kind: 1, | ||
| 198 | Tags: Tags{}, | ||
| 199 | Content: "Test message", | ||
| 200 | } | ||
| 201 | |||
| 202 | if err := key.Sign(event); err != nil { | ||
| 203 | t.Fatalf("Sign() error = %v", err) | ||
| 204 | } | ||
| 205 | |||
| 206 | // Check that all fields are set | ||
| 207 | if event.PubKey == "" { | ||
| 208 | t.Error("Sign() did not set PubKey") | ||
| 209 | } | ||
| 210 | if event.ID == "" { | ||
| 211 | t.Error("Sign() did not set ID") | ||
| 212 | } | ||
| 213 | if event.Sig == "" { | ||
| 214 | t.Error("Sign() did not set Sig") | ||
| 215 | } | ||
| 216 | |||
| 217 | // PubKey should match | ||
| 218 | if event.PubKey != key.Public() { | ||
| 219 | t.Errorf("PubKey = %s, want %s", event.PubKey, key.Public()) | ||
| 220 | } | ||
| 221 | |||
| 222 | // Signature should be 128 hex characters (64 bytes) | ||
| 223 | if len(event.Sig) != 128 { | ||
| 224 | t.Errorf("Signature length = %d, want 128", len(event.Sig)) | ||
| 225 | } | ||
| 226 | } | ||
| 227 | |||
| 228 | func TestKeySignPublicOnlyError(t *testing.T) { | ||
| 229 | fullKey, _ := GenerateKey() | ||
| 230 | pubOnlyKey, _ := ParsePublicKey(fullKey.Public()) | ||
| 231 | |||
| 232 | event := &Event{ | ||
| 233 | CreatedAt: 1704067200, | ||
| 234 | Kind: 1, | ||
| 235 | Tags: Tags{}, | ||
| 236 | Content: "Test", | ||
| 237 | } | ||
| 238 | |||
| 239 | err := pubOnlyKey.Sign(event) | ||
| 240 | if err == nil { | ||
| 241 | t.Error("Sign() with public-only key should return error") | ||
| 242 | } | ||
| 243 | } | ||
| 244 | |||
| 245 | func TestEventVerify(t *testing.T) { | ||
| 246 | key, err := GenerateKey() | ||
| 247 | if err != nil { | ||
| 248 | t.Fatalf("GenerateKey() error = %v", err) | ||
| 249 | } | ||
| 250 | |||
| 251 | event := &Event{ | ||
| 252 | CreatedAt: 1704067200, | ||
| 253 | Kind: 1, | ||
| 254 | Tags: Tags{{"test", "value"}}, | ||
| 255 | Content: "Test message for verification", | ||
| 256 | } | ||
| 257 | |||
| 258 | if err := key.Sign(event); err != nil { | ||
| 259 | t.Fatalf("Sign() error = %v", err) | ||
| 260 | } | ||
| 261 | |||
| 262 | if !event.Verify() { | ||
| 263 | t.Error("Verify() returned false for valid signature") | ||
| 264 | } | ||
| 265 | } | ||
| 266 | |||
| 267 | func TestEventVerifyInvalid(t *testing.T) { | ||
| 268 | key, err := GenerateKey() | ||
| 269 | if err != nil { | ||
| 270 | t.Fatalf("GenerateKey() error = %v", err) | ||
| 271 | } | ||
| 272 | |||
| 273 | event := &Event{ | ||
| 274 | CreatedAt: 1704067200, | ||
| 275 | Kind: 1, | ||
| 276 | Tags: Tags{}, | ||
| 277 | Content: "Test message", | ||
| 278 | } | ||
| 279 | |||
| 280 | if err := key.Sign(event); err != nil { | ||
| 281 | t.Fatalf("Sign() error = %v", err) | ||
| 282 | } | ||
| 283 | |||
| 284 | // Corrupt the content (ID becomes invalid) | ||
| 285 | event.Content = "Modified content" | ||
| 286 | if event.Verify() { | ||
| 287 | t.Error("Verify() returned true for modified content") | ||
| 288 | } | ||
| 289 | |||
| 290 | // Restore content but corrupt signature | ||
| 291 | event.Content = "Test message" | ||
| 292 | event.SetID() | ||
| 293 | event.Sig = "0000000000000000000000000000000000000000000000000000000000000000" + | ||
| 294 | "0000000000000000000000000000000000000000000000000000000000000000" | ||
| 295 | if event.Verify() { | ||
| 296 | t.Error("Verify() returned true for invalid signature") | ||
| 297 | } | ||
| 298 | } | ||
| 299 | |||
| 300 | func TestSignAndVerifyRoundTrip(t *testing.T) { | ||
| 301 | // Generate key | ||
| 302 | key, err := GenerateKey() | ||
| 303 | if err != nil { | ||
| 304 | t.Fatalf("GenerateKey() error = %v", err) | ||
| 305 | } | ||
| 306 | |||
| 307 | // Create and sign event | ||
| 308 | event := &Event{ | ||
| 309 | CreatedAt: 1704067200, | ||
| 310 | Kind: KindTextNote, | ||
| 311 | Tags: Tags{{"t", "test"}}, | ||
| 312 | Content: "Integration test message", | ||
| 313 | } | ||
| 314 | |||
| 315 | if err := key.Sign(event); err != nil { | ||
| 316 | t.Fatalf("Sign() error = %v", err) | ||
| 317 | } | ||
| 318 | |||
| 319 | // Verify public key matches | ||
| 320 | if event.PubKey != key.Public() { | ||
| 321 | t.Errorf("Signed event PubKey = %s, want %s", event.PubKey, key.Public()) | ||
| 322 | } | ||
| 323 | |||
| 324 | // Verify the signature | ||
| 325 | if !event.Verify() { | ||
| 326 | t.Error("Verify() failed for freshly signed event") | ||
| 327 | } | ||
| 328 | |||
| 329 | // Check ID is correct | ||
| 330 | if !event.CheckID() { | ||
| 331 | t.Error("CheckID() failed for freshly signed event") | ||
| 332 | } | ||
| 333 | } | ||
diff --git a/internal/nostr/kinds.go b/internal/nostr/kinds.go new file mode 100644 index 0000000..cb76e88 --- /dev/null +++ b/internal/nostr/kinds.go | |||
| @@ -0,0 +1,51 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | // Event kind constants as defined in NIP-01 and related NIPs. | ||
| 4 | const ( | ||
| 5 | KindMetadata = 0 | ||
| 6 | KindTextNote = 1 | ||
| 7 | KindContactList = 3 | ||
| 8 | KindEncryptedDM = 4 | ||
| 9 | KindDeletion = 5 | ||
| 10 | KindRepost = 6 | ||
| 11 | KindReaction = 7 | ||
| 12 | ) | ||
| 13 | |||
| 14 | // IsRegular returns true if the kind is a regular event (stored, not replaced). | ||
| 15 | // Regular events: 1000 <= kind < 10000 or kind in {0,1,2,...} except replaceable ones. | ||
| 16 | func IsRegular(kind int) bool { | ||
| 17 | if kind == KindMetadata || kind == KindContactList { | ||
| 18 | return false | ||
| 19 | } | ||
| 20 | if kind >= 10000 && kind < 20000 { | ||
| 21 | return false // replaceable | ||
| 22 | } | ||
| 23 | if kind >= 20000 && kind < 30000 { | ||
| 24 | return false // ephemeral | ||
| 25 | } | ||
| 26 | if kind >= 30000 && kind < 40000 { | ||
| 27 | return false // addressable | ||
| 28 | } | ||
| 29 | return true | ||
| 30 | } | ||
| 31 | |||
| 32 | // IsReplaceable returns true if the kind is replaceable (NIP-01). | ||
| 33 | // Replaceable events: 10000 <= kind < 20000, or kind 0 (metadata) or kind 3 (contact list). | ||
| 34 | func IsReplaceable(kind int) bool { | ||
| 35 | if kind == KindMetadata || kind == KindContactList { | ||
| 36 | return true | ||
| 37 | } | ||
| 38 | return kind >= 10000 && kind < 20000 | ||
| 39 | } | ||
| 40 | |||
| 41 | // IsEphemeral returns true if the kind is ephemeral (not stored). | ||
| 42 | // Ephemeral events: 20000 <= kind < 30000. | ||
| 43 | func IsEphemeral(kind int) bool { | ||
| 44 | return kind >= 20000 && kind < 30000 | ||
| 45 | } | ||
| 46 | |||
| 47 | // IsAddressable returns true if the kind is addressable (parameterized replaceable). | ||
| 48 | // Addressable events: 30000 <= kind < 40000. | ||
| 49 | func IsAddressable(kind int) bool { | ||
| 50 | return kind >= 30000 && kind < 40000 | ||
| 51 | } | ||
diff --git a/internal/nostr/kinds_test.go b/internal/nostr/kinds_test.go new file mode 100644 index 0000000..2bf013d --- /dev/null +++ b/internal/nostr/kinds_test.go | |||
| @@ -0,0 +1,128 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "testing" | ||
| 5 | ) | ||
| 6 | |||
| 7 | func TestKindConstants(t *testing.T) { | ||
| 8 | // Verify constants match NIP-01 spec | ||
| 9 | tests := []struct { | ||
| 10 | name string | ||
| 11 | kind int | ||
| 12 | value int | ||
| 13 | }{ | ||
| 14 | {"Metadata", KindMetadata, 0}, | ||
| 15 | {"TextNote", KindTextNote, 1}, | ||
| 16 | {"ContactList", KindContactList, 3}, | ||
| 17 | {"EncryptedDM", KindEncryptedDM, 4}, | ||
| 18 | {"Deletion", KindDeletion, 5}, | ||
| 19 | {"Repost", KindRepost, 6}, | ||
| 20 | {"Reaction", KindReaction, 7}, | ||
| 21 | } | ||
| 22 | |||
| 23 | for _, tt := range tests { | ||
| 24 | t.Run(tt.name, func(t *testing.T) { | ||
| 25 | if tt.kind != tt.value { | ||
| 26 | t.Errorf("Kind%s = %d, want %d", tt.name, tt.kind, tt.value) | ||
| 27 | } | ||
| 28 | }) | ||
| 29 | } | ||
| 30 | } | ||
| 31 | |||
| 32 | func TestIsRegular(t *testing.T) { | ||
| 33 | tests := []struct { | ||
| 34 | kind int | ||
| 35 | want bool | ||
| 36 | }{ | ||
| 37 | {0, false}, // Metadata - replaceable | ||
| 38 | {1, true}, // TextNote - regular | ||
| 39 | {3, false}, // ContactList - replaceable | ||
| 40 | {4, true}, // EncryptedDM - regular | ||
| 41 | {5, true}, // Deletion - regular | ||
| 42 | {1000, true}, // Regular range | ||
| 43 | {9999, true}, // Regular range | ||
| 44 | {10000, false}, // Replaceable range | ||
| 45 | {19999, false}, // Replaceable range | ||
| 46 | {20000, false}, // Ephemeral range | ||
| 47 | {29999, false}, // Ephemeral range | ||
| 48 | {30000, false}, // Addressable range | ||
| 49 | {39999, false}, // Addressable range | ||
| 50 | {40000, true}, // Back to regular | ||
| 51 | } | ||
| 52 | |||
| 53 | for _, tt := range tests { | ||
| 54 | t.Run("kind_"+string(rune(tt.kind)), func(t *testing.T) { | ||
| 55 | if got := IsRegular(tt.kind); got != tt.want { | ||
| 56 | t.Errorf("IsRegular(%d) = %v, want %v", tt.kind, got, tt.want) | ||
| 57 | } | ||
| 58 | }) | ||
| 59 | } | ||
| 60 | } | ||
| 61 | |||
| 62 | func TestIsReplaceable(t *testing.T) { | ||
| 63 | tests := []struct { | ||
| 64 | kind int | ||
| 65 | want bool | ||
| 66 | }{ | ||
| 67 | {0, true}, // Metadata | ||
| 68 | {1, false}, // TextNote | ||
| 69 | {3, true}, // ContactList | ||
| 70 | {10000, true}, // Replaceable range start | ||
| 71 | {15000, true}, // Replaceable range middle | ||
| 72 | {19999, true}, // Replaceable range end | ||
| 73 | {20000, false}, // Ephemeral range | ||
| 74 | {30000, false}, // Addressable range | ||
| 75 | } | ||
| 76 | |||
| 77 | for _, tt := range tests { | ||
| 78 | t.Run("kind_"+string(rune(tt.kind)), func(t *testing.T) { | ||
| 79 | if got := IsReplaceable(tt.kind); got != tt.want { | ||
| 80 | t.Errorf("IsReplaceable(%d) = %v, want %v", tt.kind, got, tt.want) | ||
| 81 | } | ||
| 82 | }) | ||
| 83 | } | ||
| 84 | } | ||
| 85 | |||
| 86 | func TestIsEphemeral(t *testing.T) { | ||
| 87 | tests := []struct { | ||
| 88 | kind int | ||
| 89 | want bool | ||
| 90 | }{ | ||
| 91 | {1, false}, // TextNote | ||
| 92 | {19999, false}, // Replaceable range | ||
| 93 | {20000, true}, // Ephemeral range start | ||
| 94 | {25000, true}, // Ephemeral range middle | ||
| 95 | {29999, true}, // Ephemeral range end | ||
| 96 | {30000, false}, // Addressable range | ||
| 97 | } | ||
| 98 | |||
| 99 | for _, tt := range tests { | ||
| 100 | t.Run("kind_"+string(rune(tt.kind)), func(t *testing.T) { | ||
| 101 | if got := IsEphemeral(tt.kind); got != tt.want { | ||
| 102 | t.Errorf("IsEphemeral(%d) = %v, want %v", tt.kind, got, tt.want) | ||
| 103 | } | ||
| 104 | }) | ||
| 105 | } | ||
| 106 | } | ||
| 107 | |||
| 108 | func TestIsAddressable(t *testing.T) { | ||
| 109 | tests := []struct { | ||
| 110 | kind int | ||
| 111 | want bool | ||
| 112 | }{ | ||
| 113 | {1, false}, // TextNote | ||
| 114 | {29999, false}, // Ephemeral range | ||
| 115 | {30000, true}, // Addressable range start | ||
| 116 | {35000, true}, // Addressable range middle | ||
| 117 | {39999, true}, // Addressable range end | ||
| 118 | {40000, false}, // Beyond addressable range | ||
| 119 | } | ||
| 120 | |||
| 121 | for _, tt := range tests { | ||
| 122 | t.Run("kind_"+string(rune(tt.kind)), func(t *testing.T) { | ||
| 123 | if got := IsAddressable(tt.kind); got != tt.want { | ||
| 124 | t.Errorf("IsAddressable(%d) = %v, want %v", tt.kind, got, tt.want) | ||
| 125 | } | ||
| 126 | }) | ||
| 127 | } | ||
| 128 | } | ||
diff --git a/internal/nostr/relay.go b/internal/nostr/relay.go new file mode 100644 index 0000000..2b156e0 --- /dev/null +++ b/internal/nostr/relay.go | |||
| @@ -0,0 +1,305 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "context" | ||
| 5 | "crypto/rand" | ||
| 6 | "fmt" | ||
| 7 | "sync" | ||
| 8 | |||
| 9 | "northwest.io/nostr-grpc/internal/websocket" | ||
| 10 | ) | ||
| 11 | |||
| 12 | // Relay represents a connection to a Nostr relay. | ||
| 13 | type Relay struct { | ||
| 14 | URL string | ||
| 15 | conn *websocket.Conn | ||
| 16 | mu sync.Mutex | ||
| 17 | |||
| 18 | subscriptions map[string]*Subscription | ||
| 19 | subscriptionsMu sync.RWMutex | ||
| 20 | |||
| 21 | okChannels map[string]chan *OKEnvelope | ||
| 22 | okChannelsMu sync.Mutex | ||
| 23 | } | ||
| 24 | |||
| 25 | // Connect establishes a WebSocket connection to the relay. | ||
| 26 | func Connect(ctx context.Context, url string) (*Relay, error) { | ||
| 27 | conn, err := websocket.Dial(ctx, url) | ||
| 28 | if err != nil { | ||
| 29 | return nil, fmt.Errorf("failed to connect to relay: %w", err) | ||
| 30 | } | ||
| 31 | |||
| 32 | r := &Relay{ | ||
| 33 | URL: url, | ||
| 34 | conn: conn, | ||
| 35 | subscriptions: make(map[string]*Subscription), | ||
| 36 | okChannels: make(map[string]chan *OKEnvelope), | ||
| 37 | } | ||
| 38 | |||
| 39 | go r.Listen(ctx) | ||
| 40 | |||
| 41 | return r, nil | ||
| 42 | } | ||
| 43 | |||
| 44 | // Close closes the WebSocket connection. | ||
| 45 | func (r *Relay) Close() error { | ||
| 46 | r.mu.Lock() | ||
| 47 | defer r.mu.Unlock() | ||
| 48 | |||
| 49 | if r.conn == nil { | ||
| 50 | return nil | ||
| 51 | } | ||
| 52 | |||
| 53 | err := r.conn.Close(websocket.StatusNormalClosure, "") | ||
| 54 | r.conn = nil | ||
| 55 | return err | ||
| 56 | } | ||
| 57 | |||
| 58 | // Send sends an envelope to the relay. | ||
| 59 | func (r *Relay) Send(ctx context.Context, env Envelope) error { | ||
| 60 | data, err := env.MarshalJSON() | ||
| 61 | if err != nil { | ||
| 62 | return fmt.Errorf("failed to marshal envelope: %w", err) | ||
| 63 | } | ||
| 64 | |||
| 65 | r.mu.Lock() | ||
| 66 | defer r.mu.Unlock() | ||
| 67 | |||
| 68 | if r.conn == nil { | ||
| 69 | return fmt.Errorf("connection closed") | ||
| 70 | } | ||
| 71 | |||
| 72 | return r.conn.Write(ctx, websocket.MessageText, data) | ||
| 73 | } | ||
| 74 | |||
| 75 | // Receive reads the next envelope from the relay. | ||
| 76 | func (r *Relay) Receive(ctx context.Context) (Envelope, error) { | ||
| 77 | r.mu.Lock() | ||
| 78 | conn := r.conn | ||
| 79 | r.mu.Unlock() | ||
| 80 | |||
| 81 | if conn == nil { | ||
| 82 | return nil, fmt.Errorf("connection closed") | ||
| 83 | } | ||
| 84 | |||
| 85 | _, data, err := conn.Read(ctx) | ||
| 86 | if err != nil { | ||
| 87 | return nil, fmt.Errorf("failed to read message: %w", err) | ||
| 88 | } | ||
| 89 | |||
| 90 | return ParseEnvelope(data) | ||
| 91 | } | ||
| 92 | |||
| 93 | // Publish sends an event to the relay and waits for the OK response. | ||
| 94 | func (r *Relay) Publish(ctx context.Context, event *Event) error { | ||
| 95 | ch := make(chan *OKEnvelope, 1) | ||
| 96 | |||
| 97 | r.okChannelsMu.Lock() | ||
| 98 | r.okChannels[event.ID] = ch | ||
| 99 | r.okChannelsMu.Unlock() | ||
| 100 | |||
| 101 | defer func() { | ||
| 102 | r.okChannelsMu.Lock() | ||
| 103 | delete(r.okChannels, event.ID) | ||
| 104 | r.okChannelsMu.Unlock() | ||
| 105 | }() | ||
| 106 | |||
| 107 | env := EventEnvelope{Event: event} | ||
| 108 | if err := r.Send(ctx, env); err != nil { | ||
| 109 | return fmt.Errorf("failed to send event: %w", err) | ||
| 110 | } | ||
| 111 | |||
| 112 | select { | ||
| 113 | case ok := <-ch: | ||
| 114 | if !ok.OK { | ||
| 115 | return fmt.Errorf("event rejected: %s", ok.Message) | ||
| 116 | } | ||
| 117 | return nil | ||
| 118 | case <-ctx.Done(): | ||
| 119 | return ctx.Err() | ||
| 120 | } | ||
| 121 | } | ||
| 122 | |||
| 123 | func genID() string { | ||
| 124 | buf := make([]byte, 5) | ||
| 125 | rand.Read(buf) | ||
| 126 | return fmt.Sprintf("%x", buf) | ||
| 127 | } | ||
| 128 | |||
| 129 | // subscribe is the internal implementation for Subscribe and Fetch. | ||
| 130 | func (r *Relay) subscribe(ctx context.Context, closeOnEOSE bool, filters ...Filter) *Subscription { | ||
| 131 | id := genID() | ||
| 132 | |||
| 133 | sub := &Subscription{ | ||
| 134 | ID: id, | ||
| 135 | relay: r, | ||
| 136 | Filters: filters, | ||
| 137 | Events: make(chan *Event, 100), | ||
| 138 | closeOnEOSE: closeOnEOSE, | ||
| 139 | } | ||
| 140 | |||
| 141 | r.subscriptionsMu.Lock() | ||
| 142 | r.subscriptions[id] = sub | ||
| 143 | r.subscriptionsMu.Unlock() | ||
| 144 | |||
| 145 | go func() { | ||
| 146 | <-ctx.Done() | ||
| 147 | sub.stop(ctx.Err()) | ||
| 148 | r.subscriptionsMu.Lock() | ||
| 149 | delete(r.subscriptions, id) | ||
| 150 | r.subscriptionsMu.Unlock() | ||
| 151 | }() | ||
| 152 | |||
| 153 | env := ReqEnvelope{ | ||
| 154 | SubscriptionID: id, | ||
| 155 | Filters: filters, | ||
| 156 | } | ||
| 157 | if err := r.Send(ctx, env); err != nil { | ||
| 158 | r.subscriptionsMu.Lock() | ||
| 159 | delete(r.subscriptions, id) | ||
| 160 | r.subscriptionsMu.Unlock() | ||
| 161 | sub.stop(fmt.Errorf("failed to send subscription request: %w", err)) | ||
| 162 | } | ||
| 163 | |||
| 164 | return sub | ||
| 165 | } | ||
| 166 | |||
| 167 | // Subscribe creates a subscription with the given filters. | ||
| 168 | // Events are received on the Events channel until the context is cancelled. | ||
| 169 | // After EOSE (end of stored events), the subscription continues to receive | ||
| 170 | // real-time events per the Nostr protocol. | ||
| 171 | func (r *Relay) Subscribe(ctx context.Context, filters ...Filter) *Subscription { | ||
| 172 | return r.subscribe(ctx, false, filters...) | ||
| 173 | } | ||
| 174 | |||
| 175 | // Fetch creates a subscription that closes automatically when EOSE is received. | ||
| 176 | // Use this for one-shot queries where you only want stored events. | ||
| 177 | func (r *Relay) Fetch(ctx context.Context, filters ...Filter) *Subscription { | ||
| 178 | return r.subscribe(ctx, true, filters...) | ||
| 179 | } | ||
| 180 | |||
| 181 | // dispatchEnvelope routes incoming messages to the appropriate subscription. | ||
| 182 | func (r *Relay) dispatchEnvelope(env Envelope) { | ||
| 183 | switch e := env.(type) { | ||
| 184 | case *EventEnvelope: | ||
| 185 | r.subscriptionsMu.RLock() | ||
| 186 | sub, ok := r.subscriptions[e.SubscriptionID] | ||
| 187 | r.subscriptionsMu.RUnlock() | ||
| 188 | if ok { | ||
| 189 | sub.send(e.Event) | ||
| 190 | } | ||
| 191 | case *EOSEEnvelope: | ||
| 192 | r.subscriptionsMu.RLock() | ||
| 193 | sub, ok := r.subscriptions[e.SubscriptionID] | ||
| 194 | r.subscriptionsMu.RUnlock() | ||
| 195 | if ok && sub.closeOnEOSE { | ||
| 196 | r.subscriptionsMu.Lock() | ||
| 197 | delete(r.subscriptions, e.SubscriptionID) | ||
| 198 | r.subscriptionsMu.Unlock() | ||
| 199 | sub.stop(nil) | ||
| 200 | } | ||
| 201 | case *ClosedEnvelope: | ||
| 202 | r.subscriptionsMu.Lock() | ||
| 203 | sub, ok := r.subscriptions[e.SubscriptionID] | ||
| 204 | if ok { | ||
| 205 | delete(r.subscriptions, e.SubscriptionID) | ||
| 206 | } | ||
| 207 | r.subscriptionsMu.Unlock() | ||
| 208 | if ok { | ||
| 209 | sub.stop(fmt.Errorf("subscription closed by relay: %s", e.Message)) | ||
| 210 | } | ||
| 211 | case *OKEnvelope: | ||
| 212 | r.okChannelsMu.Lock() | ||
| 213 | ch, ok := r.okChannels[e.EventID] | ||
| 214 | r.okChannelsMu.Unlock() | ||
| 215 | if ok { | ||
| 216 | select { | ||
| 217 | case ch <- e: | ||
| 218 | default: | ||
| 219 | } | ||
| 220 | } | ||
| 221 | } | ||
| 222 | } | ||
| 223 | |||
| 224 | // Listen reads messages from the relay and dispatches them to subscriptions. | ||
| 225 | func (r *Relay) Listen(ctx context.Context) error { | ||
| 226 | defer func() { | ||
| 227 | r.subscriptionsMu.Lock() | ||
| 228 | subs := make([]*Subscription, 0, len(r.subscriptions)) | ||
| 229 | for id, sub := range r.subscriptions { | ||
| 230 | subs = append(subs, sub) | ||
| 231 | delete(r.subscriptions, id) | ||
| 232 | } | ||
| 233 | r.subscriptionsMu.Unlock() | ||
| 234 | |||
| 235 | for _, sub := range subs { | ||
| 236 | sub.stop(fmt.Errorf("connection closed")) | ||
| 237 | } | ||
| 238 | }() | ||
| 239 | |||
| 240 | for { | ||
| 241 | select { | ||
| 242 | case <-ctx.Done(): | ||
| 243 | return ctx.Err() | ||
| 244 | default: | ||
| 245 | } | ||
| 246 | |||
| 247 | env, err := r.Receive(ctx) | ||
| 248 | if err != nil { | ||
| 249 | return err | ||
| 250 | } | ||
| 251 | |||
| 252 | r.dispatchEnvelope(env) | ||
| 253 | } | ||
| 254 | } | ||
| 255 | |||
| 256 | // Subscription represents an active subscription to a relay. | ||
| 257 | type Subscription struct { | ||
| 258 | ID string | ||
| 259 | relay *Relay | ||
| 260 | Filters []Filter | ||
| 261 | Events chan *Event | ||
| 262 | Err error | ||
| 263 | |||
| 264 | closeOnEOSE bool | ||
| 265 | mu sync.Mutex | ||
| 266 | done bool | ||
| 267 | } | ||
| 268 | |||
| 269 | // send delivers an event to the subscription's Events channel. | ||
| 270 | func (s *Subscription) send(ev *Event) { | ||
| 271 | s.mu.Lock() | ||
| 272 | defer s.mu.Unlock() | ||
| 273 | if s.done { | ||
| 274 | return | ||
| 275 | } | ||
| 276 | select { | ||
| 277 | case s.Events <- ev: | ||
| 278 | default: | ||
| 279 | } | ||
| 280 | } | ||
| 281 | |||
| 282 | // stop closes the subscription's Events channel and sets the error. | ||
| 283 | // It is idempotent — only the first call has any effect. | ||
| 284 | func (s *Subscription) stop(err error) { | ||
| 285 | s.mu.Lock() | ||
| 286 | defer s.mu.Unlock() | ||
| 287 | if s.done { | ||
| 288 | return | ||
| 289 | } | ||
| 290 | s.done = true | ||
| 291 | s.Err = err | ||
| 292 | close(s.Events) | ||
| 293 | } | ||
| 294 | |||
| 295 | // Close unsubscribes from the relay. | ||
| 296 | func (s *Subscription) Close(ctx context.Context) error { | ||
| 297 | s.stop(nil) | ||
| 298 | |||
| 299 | s.relay.subscriptionsMu.Lock() | ||
| 300 | delete(s.relay.subscriptions, s.ID) | ||
| 301 | s.relay.subscriptionsMu.Unlock() | ||
| 302 | |||
| 303 | env := CloseEnvelope{SubscriptionID: s.ID} | ||
| 304 | return s.relay.Send(ctx, env) | ||
| 305 | } | ||
diff --git a/internal/nostr/relay_test.go b/internal/nostr/relay_test.go new file mode 100644 index 0000000..02bd8e5 --- /dev/null +++ b/internal/nostr/relay_test.go | |||
| @@ -0,0 +1,326 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "context" | ||
| 5 | "encoding/json" | ||
| 6 | "net/http" | ||
| 7 | "net/http/httptest" | ||
| 8 | "strings" | ||
| 9 | "testing" | ||
| 10 | "time" | ||
| 11 | |||
| 12 | "northwest.io/nostr-grpc/internal/websocket" | ||
| 13 | ) | ||
| 14 | |||
| 15 | // mockRelay creates a test WebSocket server that echoes messages | ||
| 16 | func mockRelay(t *testing.T, handler func(conn *websocket.Conn)) *httptest.Server { | ||
| 17 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| 18 | conn, err := websocket.Accept(w, r) | ||
| 19 | if err != nil { | ||
| 20 | t.Logf("Failed to accept WebSocket: %v", err) | ||
| 21 | return | ||
| 22 | } | ||
| 23 | defer conn.Close(websocket.StatusNormalClosure, "") | ||
| 24 | |||
| 25 | handler(conn) | ||
| 26 | })) | ||
| 27 | } | ||
| 28 | |||
| 29 | func TestConnect(t *testing.T) { | ||
| 30 | server := mockRelay(t, func(conn *websocket.Conn) { | ||
| 31 | // Just accept and wait | ||
| 32 | time.Sleep(100 * time.Millisecond) | ||
| 33 | }) | ||
| 34 | defer server.Close() | ||
| 35 | |||
| 36 | url := "ws" + strings.TrimPrefix(server.URL, "http") | ||
| 37 | ctx := context.Background() | ||
| 38 | |||
| 39 | relay, err := Connect(ctx, url) | ||
| 40 | if err != nil { | ||
| 41 | t.Fatalf("Connect() error = %v", err) | ||
| 42 | } | ||
| 43 | defer relay.Close() | ||
| 44 | |||
| 45 | if relay.URL != url { | ||
| 46 | t.Errorf("Relay.URL = %s, want %s", relay.URL, url) | ||
| 47 | } | ||
| 48 | } | ||
| 49 | |||
| 50 | func TestConnectError(t *testing.T) { | ||
| 51 | ctx := context.Background() | ||
| 52 | _, err := Connect(ctx, "ws://localhost:99999") | ||
| 53 | if err == nil { | ||
| 54 | t.Error("Connect() expected error for invalid URL") | ||
| 55 | } | ||
| 56 | } | ||
| 57 | |||
| 58 | func TestRelaySendReceive(t *testing.T) { | ||
| 59 | server := mockRelay(t, func(conn *websocket.Conn) { | ||
| 60 | // Read message | ||
| 61 | _, data, err := conn.Read(context.Background()) | ||
| 62 | if err != nil { | ||
| 63 | t.Logf("Read error: %v", err) | ||
| 64 | return | ||
| 65 | } | ||
| 66 | |||
| 67 | // Echo it back as NOTICE | ||
| 68 | var arr []interface{} | ||
| 69 | json.Unmarshal(data, &arr) | ||
| 70 | |||
| 71 | response, _ := json.Marshal([]interface{}{"NOTICE", "received: " + arr[0].(string)}) | ||
| 72 | conn.Write(context.Background(), websocket.MessageText, response) | ||
| 73 | }) | ||
| 74 | defer server.Close() | ||
| 75 | |||
| 76 | url := "ws" + strings.TrimPrefix(server.URL, "http") | ||
| 77 | ctx := context.Background() | ||
| 78 | |||
| 79 | // Create relay without auto-Listen to test Send/Receive directly | ||
| 80 | conn, err := websocket.Dial(ctx, url) | ||
| 81 | if err != nil { | ||
| 82 | t.Fatalf("Dial() error = %v", err) | ||
| 83 | } | ||
| 84 | relay := &Relay{ | ||
| 85 | URL: url, | ||
| 86 | conn: conn, | ||
| 87 | subscriptions: make(map[string]*Subscription), | ||
| 88 | okChannels: make(map[string]chan *OKEnvelope), | ||
| 89 | } | ||
| 90 | defer relay.Close() | ||
| 91 | |||
| 92 | // Send a CLOSE envelope | ||
| 93 | closeEnv := CloseEnvelope{SubscriptionID: "test"} | ||
| 94 | if err := relay.Send(ctx, closeEnv); err != nil { | ||
| 95 | t.Fatalf("Send() error = %v", err) | ||
| 96 | } | ||
| 97 | |||
| 98 | // Receive response | ||
| 99 | env, err := relay.Receive(ctx) | ||
| 100 | if err != nil { | ||
| 101 | t.Fatalf("Receive() error = %v", err) | ||
| 102 | } | ||
| 103 | |||
| 104 | noticeEnv, ok := env.(*NoticeEnvelope) | ||
| 105 | if !ok { | ||
| 106 | t.Fatalf("Expected *NoticeEnvelope, got %T", env) | ||
| 107 | } | ||
| 108 | |||
| 109 | if !strings.Contains(noticeEnv.Message, "CLOSE") { | ||
| 110 | t.Errorf("Message = %s, want to contain 'CLOSE'", noticeEnv.Message) | ||
| 111 | } | ||
| 112 | } | ||
| 113 | |||
| 114 | func TestRelayPublish(t *testing.T) { | ||
| 115 | server := mockRelay(t, func(conn *websocket.Conn) { | ||
| 116 | // Read the EVENT message | ||
| 117 | _, data, err := conn.Read(context.Background()) | ||
| 118 | if err != nil { | ||
| 119 | t.Logf("Read error: %v", err) | ||
| 120 | return | ||
| 121 | } | ||
| 122 | |||
| 123 | // Parse to get event ID | ||
| 124 | var arr []json.RawMessage | ||
| 125 | json.Unmarshal(data, &arr) | ||
| 126 | |||
| 127 | var event Event | ||
| 128 | json.Unmarshal(arr[1], &event) | ||
| 129 | |||
| 130 | // Send OK response | ||
| 131 | response, _ := json.Marshal([]interface{}{"OK", event.ID, true, ""}) | ||
| 132 | conn.Write(context.Background(), websocket.MessageText, response) | ||
| 133 | }) | ||
| 134 | defer server.Close() | ||
| 135 | |||
| 136 | url := "ws" + strings.TrimPrefix(server.URL, "http") | ||
| 137 | ctx := context.Background() | ||
| 138 | |||
| 139 | relay, err := Connect(ctx, url) | ||
| 140 | if err != nil { | ||
| 141 | t.Fatalf("Connect() error = %v", err) | ||
| 142 | } | ||
| 143 | defer relay.Close() | ||
| 144 | |||
| 145 | // Create and sign event | ||
| 146 | key, _ := GenerateKey() | ||
| 147 | event := &Event{ | ||
| 148 | CreatedAt: time.Now().Unix(), | ||
| 149 | Kind: KindTextNote, | ||
| 150 | Tags: Tags{}, | ||
| 151 | Content: "Test publish", | ||
| 152 | } | ||
| 153 | key.Sign(event) | ||
| 154 | |||
| 155 | // Publish | ||
| 156 | if err := relay.Publish(ctx, event); err != nil { | ||
| 157 | t.Fatalf("Publish() error = %v", err) | ||
| 158 | } | ||
| 159 | } | ||
| 160 | |||
| 161 | func TestRelayPublishRejected(t *testing.T) { | ||
| 162 | server := mockRelay(t, func(conn *websocket.Conn) { | ||
| 163 | // Read the EVENT message | ||
| 164 | _, data, err := conn.Read(context.Background()) | ||
| 165 | if err != nil { | ||
| 166 | return | ||
| 167 | } | ||
| 168 | |||
| 169 | var arr []json.RawMessage | ||
| 170 | json.Unmarshal(data, &arr) | ||
| 171 | |||
| 172 | var event Event | ||
| 173 | json.Unmarshal(arr[1], &event) | ||
| 174 | |||
| 175 | // Send rejection | ||
| 176 | response, _ := json.Marshal([]interface{}{"OK", event.ID, false, "blocked: spam"}) | ||
| 177 | conn.Write(context.Background(), websocket.MessageText, response) | ||
| 178 | }) | ||
| 179 | defer server.Close() | ||
| 180 | |||
| 181 | url := "ws" + strings.TrimPrefix(server.URL, "http") | ||
| 182 | ctx := context.Background() | ||
| 183 | |||
| 184 | relay, err := Connect(ctx, url) | ||
| 185 | if err != nil { | ||
| 186 | t.Fatalf("Connect() error = %v", err) | ||
| 187 | } | ||
| 188 | defer relay.Close() | ||
| 189 | |||
| 190 | key, _ := GenerateKey() | ||
| 191 | event := &Event{ | ||
| 192 | CreatedAt: time.Now().Unix(), | ||
| 193 | Kind: KindTextNote, | ||
| 194 | Tags: Tags{}, | ||
| 195 | Content: "Test", | ||
| 196 | } | ||
| 197 | key.Sign(event) | ||
| 198 | |||
| 199 | err = relay.Publish(ctx, event) | ||
| 200 | if err == nil { | ||
| 201 | t.Error("Publish() expected error for rejected event") | ||
| 202 | } | ||
| 203 | if !strings.Contains(err.Error(), "rejected") { | ||
| 204 | t.Errorf("Error = %v, want to contain 'rejected'", err) | ||
| 205 | } | ||
| 206 | } | ||
| 207 | |||
| 208 | func TestRelaySubscribe(t *testing.T) { | ||
| 209 | server := mockRelay(t, func(conn *websocket.Conn) { | ||
| 210 | // Read REQ | ||
| 211 | _, data, err := conn.Read(context.Background()) | ||
| 212 | if err != nil { | ||
| 213 | return | ||
| 214 | } | ||
| 215 | |||
| 216 | var arr []json.RawMessage | ||
| 217 | json.Unmarshal(data, &arr) | ||
| 218 | |||
| 219 | var subID string | ||
| 220 | json.Unmarshal(arr[1], &subID) | ||
| 221 | |||
| 222 | // Send some events | ||
| 223 | for i := 0; i < 3; i++ { | ||
| 224 | event := Event{ | ||
| 225 | ID: "event" + string(rune('0'+i)), | ||
| 226 | PubKey: "pubkey", | ||
| 227 | CreatedAt: time.Now().Unix(), | ||
| 228 | Kind: 1, | ||
| 229 | Tags: Tags{}, | ||
| 230 | Content: "Test event", | ||
| 231 | Sig: "sig", | ||
| 232 | } | ||
| 233 | response, _ := json.Marshal([]interface{}{"EVENT", subID, event}) | ||
| 234 | conn.Write(context.Background(), websocket.MessageText, response) | ||
| 235 | } | ||
| 236 | |||
| 237 | // Send EOSE | ||
| 238 | eose, _ := json.Marshal([]interface{}{"EOSE", subID}) | ||
| 239 | conn.Write(context.Background(), websocket.MessageText, eose) | ||
| 240 | }) | ||
| 241 | defer server.Close() | ||
| 242 | |||
| 243 | url := "ws" + strings.TrimPrefix(server.URL, "http") | ||
| 244 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) | ||
| 245 | defer cancel() | ||
| 246 | |||
| 247 | relay, err := Connect(ctx, url) | ||
| 248 | if err != nil { | ||
| 249 | t.Fatalf("Connect() error = %v", err) | ||
| 250 | } | ||
| 251 | defer relay.Close() | ||
| 252 | |||
| 253 | sub := relay.Fetch(ctx, Filter{Kinds: []int{1}}) | ||
| 254 | |||
| 255 | eventCount := 0 | ||
| 256 | for range sub.Events { | ||
| 257 | eventCount++ | ||
| 258 | } | ||
| 259 | |||
| 260 | if eventCount != 3 { | ||
| 261 | t.Errorf("Received %d events, want 3", eventCount) | ||
| 262 | } | ||
| 263 | if sub.Err != nil { | ||
| 264 | t.Errorf("Subscription.Err = %v, want nil", sub.Err) | ||
| 265 | } | ||
| 266 | } | ||
| 267 | |||
| 268 | func TestRelayClose(t *testing.T) { | ||
| 269 | server := mockRelay(t, func(conn *websocket.Conn) { | ||
| 270 | time.Sleep(100 * time.Millisecond) | ||
| 271 | }) | ||
| 272 | defer server.Close() | ||
| 273 | |||
| 274 | url := "ws" + strings.TrimPrefix(server.URL, "http") | ||
| 275 | ctx := context.Background() | ||
| 276 | |||
| 277 | relay, err := Connect(ctx, url) | ||
| 278 | if err != nil { | ||
| 279 | t.Fatalf("Connect() error = %v", err) | ||
| 280 | } | ||
| 281 | |||
| 282 | if err := relay.Close(); err != nil { | ||
| 283 | t.Errorf("Close() error = %v", err) | ||
| 284 | } | ||
| 285 | |||
| 286 | // Second close should be safe | ||
| 287 | if err := relay.Close(); err != nil { | ||
| 288 | t.Errorf("Second Close() error = %v", err) | ||
| 289 | } | ||
| 290 | } | ||
| 291 | |||
| 292 | func TestSubscriptionClose(t *testing.T) { | ||
| 293 | server := mockRelay(t, func(conn *websocket.Conn) { | ||
| 294 | // Read REQ | ||
| 295 | conn.Read(context.Background()) | ||
| 296 | |||
| 297 | // Wait for CLOSE | ||
| 298 | _, data, err := conn.Read(context.Background()) | ||
| 299 | if err != nil { | ||
| 300 | return | ||
| 301 | } | ||
| 302 | |||
| 303 | var arr []interface{} | ||
| 304 | json.Unmarshal(data, &arr) | ||
| 305 | |||
| 306 | if arr[0] != "CLOSE" { | ||
| 307 | t.Errorf("Expected CLOSE, got %v", arr[0]) | ||
| 308 | } | ||
| 309 | }) | ||
| 310 | defer server.Close() | ||
| 311 | |||
| 312 | url := "ws" + strings.TrimPrefix(server.URL, "http") | ||
| 313 | ctx := context.Background() | ||
| 314 | |||
| 315 | relay, err := Connect(ctx, url) | ||
| 316 | if err != nil { | ||
| 317 | t.Fatalf("Connect() error = %v", err) | ||
| 318 | } | ||
| 319 | defer relay.Close() | ||
| 320 | |||
| 321 | sub := relay.Subscribe(ctx, Filter{Kinds: []int{1}}) | ||
| 322 | |||
| 323 | if err := sub.Close(ctx); err != nil { | ||
| 324 | t.Errorf("Subscription.Close() error = %v", err) | ||
| 325 | } | ||
| 326 | } | ||
diff --git a/internal/nostr/tags.go b/internal/nostr/tags.go new file mode 100644 index 0000000..4fe3d04 --- /dev/null +++ b/internal/nostr/tags.go | |||
| @@ -0,0 +1,64 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | // Tag represents a single Nostr tag, which is an array of strings. | ||
| 4 | // The first element is the tag key, followed by its values. | ||
| 5 | type Tag []string | ||
| 6 | |||
| 7 | // Key returns the tag key (first element), or empty string if tag is empty. | ||
| 8 | func (t Tag) Key() string { | ||
| 9 | if len(t) == 0 { | ||
| 10 | return "" | ||
| 11 | } | ||
| 12 | return t[0] | ||
| 13 | } | ||
| 14 | |||
| 15 | // Value returns the first value (second element), or empty string if not present. | ||
| 16 | func (t Tag) Value() string { | ||
| 17 | if len(t) < 2 { | ||
| 18 | return "" | ||
| 19 | } | ||
| 20 | return t[1] | ||
| 21 | } | ||
| 22 | |||
| 23 | // Tags represents a collection of tags. | ||
| 24 | type Tags []Tag | ||
| 25 | |||
| 26 | // Find returns the first tag matching the given key, or nil if not found. | ||
| 27 | func (tags Tags) Find(key string) Tag { | ||
| 28 | for _, tag := range tags { | ||
| 29 | if tag.Key() == key { | ||
| 30 | return tag | ||
| 31 | } | ||
| 32 | } | ||
| 33 | return nil | ||
| 34 | } | ||
| 35 | |||
| 36 | // FindAll returns all tags matching the given key. | ||
| 37 | func (tags Tags) FindAll(key string) Tags { | ||
| 38 | var result Tags | ||
| 39 | for _, tag := range tags { | ||
| 40 | if tag.Key() == key { | ||
| 41 | result = append(result, tag) | ||
| 42 | } | ||
| 43 | } | ||
| 44 | return result | ||
| 45 | } | ||
| 46 | |||
| 47 | // GetD returns the value of the "d" tag, used for addressable events. | ||
| 48 | func (tags Tags) GetD() string { | ||
| 49 | tag := tags.Find("d") | ||
| 50 | if tag == nil { | ||
| 51 | return "" | ||
| 52 | } | ||
| 53 | return tag.Value() | ||
| 54 | } | ||
| 55 | |||
| 56 | // ContainsValue checks if any tag with the given key contains the specified value. | ||
| 57 | func (tags Tags) ContainsValue(key, value string) bool { | ||
| 58 | for _, tag := range tags { | ||
| 59 | if tag.Key() == key && tag.Value() == value { | ||
| 60 | return true | ||
| 61 | } | ||
| 62 | } | ||
| 63 | return false | ||
| 64 | } | ||
diff --git a/internal/nostr/tags_test.go b/internal/nostr/tags_test.go new file mode 100644 index 0000000..7796606 --- /dev/null +++ b/internal/nostr/tags_test.go | |||
| @@ -0,0 +1,158 @@ | |||
| 1 | package nostr | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "testing" | ||
| 5 | ) | ||
| 6 | |||
| 7 | func TestTagKey(t *testing.T) { | ||
| 8 | tests := []struct { | ||
| 9 | name string | ||
| 10 | tag Tag | ||
| 11 | want string | ||
| 12 | }{ | ||
| 13 | {"empty tag", Tag{}, ""}, | ||
| 14 | {"single element", Tag{"e"}, "e"}, | ||
| 15 | {"multiple elements", Tag{"p", "abc123", "relay"}, "p"}, | ||
| 16 | } | ||
| 17 | |||
| 18 | for _, tt := range tests { | ||
| 19 | t.Run(tt.name, func(t *testing.T) { | ||
| 20 | if got := tt.tag.Key(); got != tt.want { | ||
| 21 | t.Errorf("Tag.Key() = %q, want %q", got, tt.want) | ||
| 22 | } | ||
| 23 | }) | ||
| 24 | } | ||
| 25 | } | ||
| 26 | |||
| 27 | func TestTagValue(t *testing.T) { | ||
| 28 | tests := []struct { | ||
| 29 | name string | ||
| 30 | tag Tag | ||
| 31 | want string | ||
| 32 | }{ | ||
| 33 | {"empty tag", Tag{}, ""}, | ||
| 34 | {"single element", Tag{"e"}, ""}, | ||
| 35 | {"two elements", Tag{"p", "abc123"}, "abc123"}, | ||
| 36 | {"multiple elements", Tag{"e", "eventid", "relay", "marker"}, "eventid"}, | ||
| 37 | } | ||
| 38 | |||
| 39 | for _, tt := range tests { | ||
| 40 | t.Run(tt.name, func(t *testing.T) { | ||
| 41 | if got := tt.tag.Value(); got != tt.want { | ||
| 42 | t.Errorf("Tag.Value() = %q, want %q", got, tt.want) | ||
| 43 | } | ||
| 44 | }) | ||
| 45 | } | ||
| 46 | } | ||
| 47 | |||
| 48 | func TestTagsFind(t *testing.T) { | ||
| 49 | tags := Tags{ | ||
| 50 | {"e", "event1"}, | ||
| 51 | {"p", "pubkey1"}, | ||
| 52 | {"e", "event2"}, | ||
| 53 | {"d", "identifier"}, | ||
| 54 | } | ||
| 55 | |||
| 56 | tests := []struct { | ||
| 57 | name string | ||
| 58 | key string | ||
| 59 | wantNil bool | ||
| 60 | wantVal string | ||
| 61 | }{ | ||
| 62 | {"find first e", "e", false, "event1"}, | ||
| 63 | {"find p", "p", false, "pubkey1"}, | ||
| 64 | {"find d", "d", false, "identifier"}, | ||
| 65 | {"find nonexistent", "x", true, ""}, | ||
| 66 | } | ||
| 67 | |||
| 68 | for _, tt := range tests { | ||
| 69 | t.Run(tt.name, func(t *testing.T) { | ||
| 70 | got := tags.Find(tt.key) | ||
| 71 | if tt.wantNil { | ||
| 72 | if got != nil { | ||
| 73 | t.Errorf("Tags.Find(%q) = %v, want nil", tt.key, got) | ||
| 74 | } | ||
| 75 | } else { | ||
| 76 | if got == nil { | ||
| 77 | t.Errorf("Tags.Find(%q) = nil, want value %q", tt.key, tt.wantVal) | ||
| 78 | } else if got.Value() != tt.wantVal { | ||
| 79 | t.Errorf("Tags.Find(%q).Value() = %q, want %q", tt.key, got.Value(), tt.wantVal) | ||
| 80 | } | ||
| 81 | } | ||
| 82 | }) | ||
| 83 | } | ||
| 84 | } | ||
| 85 | |||
| 86 | func TestTagsFindAll(t *testing.T) { | ||
| 87 | tags := Tags{ | ||
| 88 | {"e", "event1"}, | ||
| 89 | {"p", "pubkey1"}, | ||
| 90 | {"e", "event2"}, | ||
| 91 | {"e", "event3"}, | ||
| 92 | } | ||
| 93 | |||
| 94 | found := tags.FindAll("e") | ||
| 95 | if len(found) != 3 { | ||
| 96 | t.Errorf("Tags.FindAll(\"e\") returned %d tags, want 3", len(found)) | ||
| 97 | } | ||
| 98 | |||
| 99 | found = tags.FindAll("p") | ||
| 100 | if len(found) != 1 { | ||
| 101 | t.Errorf("Tags.FindAll(\"p\") returned %d tags, want 1", len(found)) | ||
| 102 | } | ||
| 103 | |||
| 104 | found = tags.FindAll("x") | ||
| 105 | if len(found) != 0 { | ||
| 106 | t.Errorf("Tags.FindAll(\"x\") returned %d tags, want 0", len(found)) | ||
| 107 | } | ||
| 108 | } | ||
| 109 | |||
| 110 | func TestTagsGetD(t *testing.T) { | ||
| 111 | tests := []struct { | ||
| 112 | name string | ||
| 113 | tags Tags | ||
| 114 | want string | ||
| 115 | }{ | ||
| 116 | {"no d tag", Tags{{"e", "event1"}}, ""}, | ||
| 117 | {"empty d tag", Tags{{"d"}}, ""}, | ||
| 118 | {"d tag present", Tags{{"d", "my-identifier"}}, "my-identifier"}, | ||
| 119 | {"d tag with extras", Tags{{"d", "id", "extra"}}, "id"}, | ||
| 120 | } | ||
| 121 | |||
| 122 | for _, tt := range tests { | ||
| 123 | t.Run(tt.name, func(t *testing.T) { | ||
| 124 | if got := tt.tags.GetD(); got != tt.want { | ||
| 125 | t.Errorf("Tags.GetD() = %q, want %q", got, tt.want) | ||
| 126 | } | ||
| 127 | }) | ||
| 128 | } | ||
| 129 | } | ||
| 130 | |||
| 131 | func TestTagsContainsValue(t *testing.T) { | ||
| 132 | tags := Tags{ | ||
| 133 | {"e", "event1"}, | ||
| 134 | {"p", "pubkey1"}, | ||
| 135 | {"e", "event2"}, | ||
| 136 | } | ||
| 137 | |||
| 138 | tests := []struct { | ||
| 139 | key string | ||
| 140 | value string | ||
| 141 | want bool | ||
| 142 | }{ | ||
| 143 | {"e", "event1", true}, | ||
| 144 | {"e", "event2", true}, | ||
| 145 | {"e", "event3", false}, | ||
| 146 | {"p", "pubkey1", true}, | ||
| 147 | {"p", "pubkey2", false}, | ||
| 148 | {"x", "anything", false}, | ||
| 149 | } | ||
| 150 | |||
| 151 | for _, tt := range tests { | ||
| 152 | t.Run(tt.key+"="+tt.value, func(t *testing.T) { | ||
| 153 | if got := tags.ContainsValue(tt.key, tt.value); got != tt.want { | ||
| 154 | t.Errorf("Tags.ContainsValue(%q, %q) = %v, want %v", tt.key, tt.value, got, tt.want) | ||
| 155 | } | ||
| 156 | }) | ||
| 157 | } | ||
| 158 | } | ||
diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go new file mode 100644 index 0000000..fe937c8 --- /dev/null +++ b/internal/websocket/websocket.go | |||
| @@ -0,0 +1,297 @@ | |||
| 1 | package websocket | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "bufio" | ||
| 5 | "context" | ||
| 6 | "crypto/rand" | ||
| 7 | "crypto/sha1" | ||
| 8 | "crypto/tls" | ||
| 9 | "encoding/base64" | ||
| 10 | "encoding/binary" | ||
| 11 | "fmt" | ||
| 12 | "io" | ||
| 13 | "net" | ||
| 14 | "net/http" | ||
| 15 | "net/url" | ||
| 16 | "strings" | ||
| 17 | "sync" | ||
| 18 | "time" | ||
| 19 | ) | ||
| 20 | |||
| 21 | type MessageType int | ||
| 22 | |||
| 23 | const MessageText MessageType = 1 | ||
| 24 | |||
| 25 | type StatusCode int | ||
| 26 | |||
| 27 | const StatusNormalClosure StatusCode = 1000 | ||
| 28 | |||
| 29 | const ( | ||
| 30 | opText = 0x1 | ||
| 31 | opClose = 0x8 | ||
| 32 | opPing = 0x9 | ||
| 33 | opPong = 0xA | ||
| 34 | ) | ||
| 35 | |||
| 36 | type Conn struct { | ||
| 37 | rwc net.Conn | ||
| 38 | br *bufio.Reader | ||
| 39 | client bool | ||
| 40 | mu sync.Mutex | ||
| 41 | } | ||
| 42 | |||
| 43 | func mask(key [4]byte, data []byte) { | ||
| 44 | for i := range data { | ||
| 45 | data[i] ^= key[i%4] | ||
| 46 | } | ||
| 47 | } | ||
| 48 | |||
| 49 | func (c *Conn) writeFrame(opcode byte, payload []byte) error { | ||
| 50 | c.mu.Lock() | ||
| 51 | defer c.mu.Unlock() | ||
| 52 | |||
| 53 | length := len(payload) | ||
| 54 | header := []byte{0x80 | opcode, 0} // FIN + opcode | ||
| 55 | |||
| 56 | if c.client { | ||
| 57 | header[1] = 0x80 // mask bit | ||
| 58 | } | ||
| 59 | |||
| 60 | switch { | ||
| 61 | case length <= 125: | ||
| 62 | header[1] |= byte(length) | ||
| 63 | case length <= 65535: | ||
| 64 | header[1] |= 126 | ||
| 65 | ext := make([]byte, 2) | ||
| 66 | binary.BigEndian.PutUint16(ext, uint16(length)) | ||
| 67 | header = append(header, ext...) | ||
| 68 | default: | ||
| 69 | header[1] |= 127 | ||
| 70 | ext := make([]byte, 8) | ||
| 71 | binary.BigEndian.PutUint64(ext, uint64(length)) | ||
| 72 | header = append(header, ext...) | ||
| 73 | } | ||
| 74 | |||
| 75 | if c.client { | ||
| 76 | var key [4]byte | ||
| 77 | rand.Read(key[:]) | ||
| 78 | header = append(header, key[:]...) | ||
| 79 | mask(key, payload) | ||
| 80 | } | ||
| 81 | |||
| 82 | if _, err := c.rwc.Write(header); err != nil { | ||
| 83 | return err | ||
| 84 | } | ||
| 85 | _, err := c.rwc.Write(payload) | ||
| 86 | return err | ||
| 87 | } | ||
| 88 | |||
| 89 | func (c *Conn) readFrame() (fin bool, opcode byte, payload []byte, err error) { | ||
| 90 | var hdr [2]byte | ||
| 91 | if _, err = io.ReadFull(c.br, hdr[:]); err != nil { | ||
| 92 | return | ||
| 93 | } | ||
| 94 | |||
| 95 | fin = hdr[0]&0x80 != 0 | ||
| 96 | opcode = hdr[0] & 0x0F | ||
| 97 | masked := hdr[1]&0x80 != 0 | ||
| 98 | length := uint64(hdr[1] & 0x7F) | ||
| 99 | |||
| 100 | switch length { | ||
| 101 | case 126: | ||
| 102 | var ext [2]byte | ||
| 103 | if _, err = io.ReadFull(c.br, ext[:]); err != nil { | ||
| 104 | return | ||
| 105 | } | ||
| 106 | length = uint64(binary.BigEndian.Uint16(ext[:])) | ||
| 107 | case 127: | ||
| 108 | var ext [8]byte | ||
| 109 | if _, err = io.ReadFull(c.br, ext[:]); err != nil { | ||
| 110 | return | ||
| 111 | } | ||
| 112 | length = binary.BigEndian.Uint64(ext[:]) | ||
| 113 | } | ||
| 114 | |||
| 115 | var key [4]byte | ||
| 116 | if masked { | ||
| 117 | if _, err = io.ReadFull(c.br, key[:]); err != nil { | ||
| 118 | return | ||
| 119 | } | ||
| 120 | } | ||
| 121 | |||
| 122 | payload = make([]byte, length) | ||
| 123 | if _, err = io.ReadFull(c.br, payload); err != nil { | ||
| 124 | return | ||
| 125 | } | ||
| 126 | |||
| 127 | if masked { | ||
| 128 | mask(key, payload) | ||
| 129 | } | ||
| 130 | return | ||
| 131 | } | ||
| 132 | |||
| 133 | func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { | ||
| 134 | stop := context.AfterFunc(ctx, func() { | ||
| 135 | c.rwc.SetReadDeadline(time.Now()) | ||
| 136 | }) | ||
| 137 | defer stop() | ||
| 138 | |||
| 139 | var buf []byte | ||
| 140 | for { | ||
| 141 | fin, opcode, payload, err := c.readFrame() | ||
| 142 | if err != nil { | ||
| 143 | if ctx.Err() != nil { | ||
| 144 | return 0, nil, ctx.Err() | ||
| 145 | } | ||
| 146 | return 0, nil, err | ||
| 147 | } | ||
| 148 | |||
| 149 | switch opcode { | ||
| 150 | case opPing: | ||
| 151 | c.writeFrame(opPong, payload) | ||
| 152 | continue | ||
| 153 | case opClose: | ||
| 154 | return 0, nil, fmt.Errorf("websocket: close frame received") | ||
| 155 | case opText, 0x0: // text or continuation | ||
| 156 | buf = append(buf, payload...) | ||
| 157 | if fin { | ||
| 158 | return MessageText, buf, nil | ||
| 159 | } | ||
| 160 | default: | ||
| 161 | buf = append(buf, payload...) | ||
| 162 | if fin { | ||
| 163 | return MessageText, buf, nil | ||
| 164 | } | ||
| 165 | } | ||
| 166 | } | ||
| 167 | } | ||
| 168 | |||
| 169 | func (c *Conn) Write(ctx context.Context, typ MessageType, data []byte) error { | ||
| 170 | return c.writeFrame(byte(typ), data) | ||
| 171 | } | ||
| 172 | |||
| 173 | func (c *Conn) Close(code StatusCode, reason string) error { | ||
| 174 | payload := make([]byte, 2+len(reason)) | ||
| 175 | binary.BigEndian.PutUint16(payload, uint16(code)) | ||
| 176 | copy(payload[2:], reason) | ||
| 177 | c.writeFrame(opClose, payload) | ||
| 178 | return c.rwc.Close() | ||
| 179 | } | ||
| 180 | |||
| 181 | var wsGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" | ||
| 182 | |||
| 183 | func acceptKey(key string) string { | ||
| 184 | h := sha1.New() | ||
| 185 | h.Write([]byte(key)) | ||
| 186 | h.Write([]byte(wsGUID)) | ||
| 187 | return base64.StdEncoding.EncodeToString(h.Sum(nil)) | ||
| 188 | } | ||
| 189 | |||
| 190 | func Dial(ctx context.Context, rawURL string) (*Conn, error) { | ||
| 191 | u, err := url.Parse(rawURL) | ||
| 192 | if err != nil { | ||
| 193 | return nil, err | ||
| 194 | } | ||
| 195 | |||
| 196 | host := u.Hostname() | ||
| 197 | port := u.Port() | ||
| 198 | useTLS := u.Scheme == "wss" | ||
| 199 | |||
| 200 | if port == "" { | ||
| 201 | if useTLS { | ||
| 202 | port = "443" | ||
| 203 | } else { | ||
| 204 | port = "80" | ||
| 205 | } | ||
| 206 | } | ||
| 207 | |||
| 208 | addr := net.JoinHostPort(host, port) | ||
| 209 | rwc, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) | ||
| 210 | if err != nil { | ||
| 211 | return nil, err | ||
| 212 | } | ||
| 213 | |||
| 214 | if useTLS { | ||
| 215 | tc := tls.Client(rwc, &tls.Config{ServerName: host}) | ||
| 216 | if err := tc.HandshakeContext(ctx); err != nil { | ||
| 217 | rwc.Close() | ||
| 218 | return nil, err | ||
| 219 | } | ||
| 220 | rwc = tc | ||
| 221 | } | ||
| 222 | |||
| 223 | br := bufio.NewReader(rwc) | ||
| 224 | |||
| 225 | var keyBytes [16]byte | ||
| 226 | rand.Read(keyBytes[:]) | ||
| 227 | key := base64.StdEncoding.EncodeToString(keyBytes[:]) | ||
| 228 | |||
| 229 | path := u.RequestURI() | ||
| 230 | reqStr := "GET " + path + " HTTP/1.1\r\n" + | ||
| 231 | "Host: " + host + "\r\n" + | ||
| 232 | "Upgrade: websocket\r\n" + | ||
| 233 | "Connection: Upgrade\r\n" + | ||
| 234 | "Sec-WebSocket-Key: " + key + "\r\n" + | ||
| 235 | "Sec-WebSocket-Version: 13\r\n\r\n" | ||
| 236 | |||
| 237 | if _, err := rwc.Write([]byte(reqStr)); err != nil { | ||
| 238 | rwc.Close() | ||
| 239 | return nil, err | ||
| 240 | } | ||
| 241 | |||
| 242 | req := &http.Request{Method: "GET"} | ||
| 243 | resp, err := http.ReadResponse(br, req) | ||
| 244 | if err != nil { | ||
| 245 | rwc.Close() | ||
| 246 | return nil, err | ||
| 247 | } | ||
| 248 | resp.Body.Close() | ||
| 249 | |||
| 250 | if resp.StatusCode != 101 { | ||
| 251 | rwc.Close() | ||
| 252 | return nil, fmt.Errorf("websocket: bad handshake status %d", resp.StatusCode) | ||
| 253 | } | ||
| 254 | |||
| 255 | got := resp.Header.Get("Sec-WebSocket-Accept") | ||
| 256 | want := acceptKey(key) | ||
| 257 | if got != want { | ||
| 258 | rwc.Close() | ||
| 259 | return nil, fmt.Errorf("websocket: invalid Sec-WebSocket-Accept") | ||
| 260 | } | ||
| 261 | |||
| 262 | return &Conn{rwc: rwc, br: br, client: true}, nil | ||
| 263 | } | ||
| 264 | |||
| 265 | func Accept(w http.ResponseWriter, r *http.Request) (*Conn, error) { | ||
| 266 | if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { | ||
| 267 | return nil, fmt.Errorf("websocket: missing Upgrade header") | ||
| 268 | } | ||
| 269 | |||
| 270 | key := r.Header.Get("Sec-WebSocket-Key") | ||
| 271 | if key == "" { | ||
| 272 | return nil, fmt.Errorf("websocket: missing Sec-WebSocket-Key") | ||
| 273 | } | ||
| 274 | |||
| 275 | hj, ok := w.(http.Hijacker) | ||
| 276 | if !ok { | ||
| 277 | return nil, fmt.Errorf("websocket: response does not support hijacking") | ||
| 278 | } | ||
| 279 | |||
| 280 | rwc, brw, err := hj.Hijack() | ||
| 281 | if err != nil { | ||
| 282 | return nil, err | ||
| 283 | } | ||
| 284 | |||
| 285 | accept := acceptKey(key) | ||
| 286 | respStr := "HTTP/1.1 101 Switching Protocols\r\n" + | ||
| 287 | "Upgrade: websocket\r\n" + | ||
| 288 | "Connection: Upgrade\r\n" + | ||
| 289 | "Sec-WebSocket-Accept: " + accept + "\r\n\r\n" | ||
| 290 | |||
| 291 | if _, err := rwc.Write([]byte(respStr)); err != nil { | ||
| 292 | rwc.Close() | ||
| 293 | return nil, err | ||
| 294 | } | ||
| 295 | |||
| 296 | return &Conn{rwc: rwc, br: brw.Reader, client: false}, nil | ||
| 297 | } | ||
