From 581ceecbf046f99b39885c74e2780a5320e5b15e Mon Sep 17 00:00:00 2001 From: bndw Date: Fri, 13 Feb 2026 17:35:32 -0800 Subject: feat: add Nostr protocol implementation (internal/nostr, internal/websocket) --- internal/nostr/bech32.go | 162 ++++++++++++++++ internal/nostr/bech32_test.go | 139 ++++++++++++++ internal/nostr/envelope.go | 262 +++++++++++++++++++++++++ internal/nostr/envelope_test.go | 416 ++++++++++++++++++++++++++++++++++++++++ internal/nostr/event.go | 72 +++++++ internal/nostr/event_test.go | 194 +++++++++++++++++++ internal/nostr/example_test.go | 85 ++++++++ internal/nostr/filter.go | 224 ++++++++++++++++++++++ internal/nostr/filter_test.go | 415 +++++++++++++++++++++++++++++++++++++++ internal/nostr/keys.go | 217 +++++++++++++++++++++ internal/nostr/keys_test.go | 333 ++++++++++++++++++++++++++++++++ internal/nostr/kinds.go | 51 +++++ internal/nostr/kinds_test.go | 128 +++++++++++++ internal/nostr/relay.go | 305 +++++++++++++++++++++++++++++ internal/nostr/relay_test.go | 326 +++++++++++++++++++++++++++++++ internal/nostr/tags.go | 64 +++++++ internal/nostr/tags_test.go | 158 +++++++++++++++ internal/websocket/websocket.go | 297 ++++++++++++++++++++++++++++ 18 files changed, 3848 insertions(+) create mode 100644 internal/nostr/bech32.go create mode 100644 internal/nostr/bech32_test.go create mode 100644 internal/nostr/envelope.go create mode 100644 internal/nostr/envelope_test.go create mode 100644 internal/nostr/event.go create mode 100644 internal/nostr/event_test.go create mode 100644 internal/nostr/example_test.go create mode 100644 internal/nostr/filter.go create mode 100644 internal/nostr/filter_test.go create mode 100644 internal/nostr/keys.go create mode 100644 internal/nostr/keys_test.go create mode 100644 internal/nostr/kinds.go create mode 100644 internal/nostr/kinds_test.go create mode 100644 internal/nostr/relay.go create mode 100644 internal/nostr/relay_test.go create mode 100644 internal/nostr/tags.go create mode 100644 internal/nostr/tags_test.go create mode 100644 internal/websocket/websocket.go diff --git a/internal/nostr/bech32.go b/internal/nostr/bech32.go new file mode 100644 index 0000000..c8b1293 --- /dev/null +++ b/internal/nostr/bech32.go @@ -0,0 +1,162 @@ +package nostr + +import ( + "fmt" + "strings" +) + +// Bech32 encoding/decoding for NIP-19 (npub, nsec, note, etc.) +// Implements BIP-173 bech32 encoding. + +const bech32Alphabet = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" + +var bech32AlphabetMap [256]int8 + +func init() { + for i := range bech32AlphabetMap { + bech32AlphabetMap[i] = -1 + } + for i, c := range bech32Alphabet { + bech32AlphabetMap[c] = int8(i) + } +} + +// bech32Polymod computes the BCH checksum. +func bech32Polymod(values []int) int { + gen := []int{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3} + chk := 1 + for _, v := range values { + top := chk >> 25 + chk = (chk&0x1ffffff)<<5 ^ v + for i := 0; i < 5; i++ { + if (top>>i)&1 == 1 { + chk ^= gen[i] + } + } + } + return chk +} + +// bech32HRPExpand expands the human-readable part for checksum computation. +func bech32HRPExpand(hrp string) []int { + result := make([]int, len(hrp)*2+1) + for i, c := range hrp { + result[i] = int(c) >> 5 + result[i+len(hrp)+1] = int(c) & 31 + } + return result +} + +// bech32CreateChecksum creates the 6-character checksum. +func bech32CreateChecksum(hrp string, data []int) []int { + values := append(bech32HRPExpand(hrp), data...) + values = append(values, []int{0, 0, 0, 0, 0, 0}...) + polymod := bech32Polymod(values) ^ 1 + checksum := make([]int, 6) + for i := 0; i < 6; i++ { + checksum[i] = (polymod >> (5 * (5 - i))) & 31 + } + return checksum +} + +// bech32VerifyChecksum verifies the checksum of bech32 data. +func bech32VerifyChecksum(hrp string, data []int) bool { + return bech32Polymod(append(bech32HRPExpand(hrp), data...)) == 1 +} + +// convertBits converts between bit groups. +func convertBits(data []byte, fromBits, toBits int, pad bool) ([]int, error) { + acc := 0 + bits := 0 + result := make([]int, 0, len(data)*fromBits/toBits+1) + maxv := (1 << toBits) - 1 + + for _, value := range data { + acc = (acc << fromBits) | int(value) + bits += fromBits + for bits >= toBits { + bits -= toBits + result = append(result, (acc>>bits)&maxv) + } + } + + if pad { + if bits > 0 { + result = append(result, (acc<<(toBits-bits))&maxv) + } + } else if bits >= fromBits || ((acc<<(toBits-bits))&maxv) != 0 { + return nil, fmt.Errorf("invalid padding") + } + + return result, nil +} + +// Bech32Encode encodes data with the given human-readable prefix. +func Bech32Encode(hrp string, data []byte) (string, error) { + values, err := convertBits(data, 8, 5, true) + if err != nil { + return "", err + } + + checksum := bech32CreateChecksum(hrp, values) + combined := append(values, checksum...) + + var result strings.Builder + result.WriteString(hrp) + result.WriteByte('1') + for _, v := range combined { + result.WriteByte(bech32Alphabet[v]) + } + + return result.String(), nil +} + +// Bech32Decode decodes a bech32 string, returning the HRP and data. +func Bech32Decode(s string) (string, []byte, error) { + s = strings.ToLower(s) + + pos := strings.LastIndexByte(s, '1') + if pos < 1 || pos+7 > len(s) { + return "", nil, fmt.Errorf("invalid bech32 string") + } + + hrp := s[:pos] + dataStr := s[pos+1:] + + data := make([]int, len(dataStr)) + for i, c := range dataStr { + val := bech32AlphabetMap[c] + if val == -1 { + return "", nil, fmt.Errorf("invalid character: %c", c) + } + data[i] = int(val) + } + + if !bech32VerifyChecksum(hrp, data) { + return "", nil, fmt.Errorf("invalid checksum") + } + + // Remove checksum + data = data[:len(data)-6] + + // Convert from 5-bit to 8-bit + result, err := convertBits(intSliceToBytes(data), 5, 8, false) + if err != nil { + return "", nil, err + } + + bytes := make([]byte, len(result)) + for i, v := range result { + bytes[i] = byte(v) + } + + return hrp, bytes, nil +} + +func intSliceToBytes(data []int) []byte { + result := make([]byte, len(data)) + for i, v := range data { + result[i] = byte(v) + } + return result +} diff --git a/internal/nostr/bech32_test.go b/internal/nostr/bech32_test.go new file mode 100644 index 0000000..fb1260b --- /dev/null +++ b/internal/nostr/bech32_test.go @@ -0,0 +1,139 @@ +package nostr + +import ( + "bytes" + "encoding/hex" + "testing" +) + +func TestBech32Encode(t *testing.T) { + // Test vector: 32 bytes of data + data, _ := hex.DecodeString("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef") + + encoded, err := Bech32Encode("npub", data) + if err != nil { + t.Fatalf("Bech32Encode() error = %v", err) + } + + if encoded[:5] != "npub1" { + t.Errorf("Encoded string should start with 'npub1', got %s", encoded[:5]) + } + + // Decode it back + hrp, decoded, err := Bech32Decode(encoded) + if err != nil { + t.Fatalf("Bech32Decode() error = %v", err) + } + + if hrp != "npub" { + t.Errorf("HRP = %s, want npub", hrp) + } + + if !bytes.Equal(decoded, data) { + t.Errorf("Round-trip failed: got %x, want %x", decoded, data) + } +} + +func TestBech32EncodeNsec(t *testing.T) { + data, _ := hex.DecodeString("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + + encoded, err := Bech32Encode("nsec", data) + if err != nil { + t.Fatalf("Bech32Encode() error = %v", err) + } + + if encoded[:5] != "nsec1" { + t.Errorf("Encoded string should start with 'nsec1', got %s", encoded[:5]) + } + + // Decode it back + hrp, decoded, err := Bech32Decode(encoded) + if err != nil { + t.Fatalf("Bech32Decode() error = %v", err) + } + + if hrp != "nsec" { + t.Errorf("HRP = %s, want nsec", hrp) + } + + if !bytes.Equal(decoded, data) { + t.Errorf("Round-trip failed") + } +} + +func TestBech32DecodeErrors(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"no separator", "npubabcdef"}, + {"empty data", "npub1"}, + {"invalid character", "npub1o"}, // 'o' is not in bech32 alphabet + {"invalid checksum", "npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqpqqqqq"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err := Bech32Decode(tt.input) + if err == nil { + t.Error("Bech32Decode() expected error, got nil") + } + }) + } +} + +func TestBech32KnownVectors(t *testing.T) { + // Test with known nostr npub/nsec values + // These can be verified with other nostr tools + + // Generate a key and verify round-trip + key, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + npub := key.Npub() + nsec := key.Nsec() + + // Verify npub decodes to public key + hrp, pubBytes, err := Bech32Decode(npub) + if err != nil { + t.Fatalf("Bech32Decode(npub) error = %v", err) + } + if hrp != "npub" { + t.Errorf("npub HRP = %s, want npub", hrp) + } + if hex.EncodeToString(pubBytes) != key.Public() { + t.Error("npub does not decode to correct public key") + } + + // Verify nsec decodes to private key + hrp, privBytes, err := Bech32Decode(nsec) + if err != nil { + t.Fatalf("Bech32Decode(nsec) error = %v", err) + } + if hrp != "nsec" { + t.Errorf("nsec HRP = %s, want nsec", hrp) + } + if hex.EncodeToString(privBytes) != key.Private() { + t.Error("nsec does not decode to correct private key") + } +} + +func TestBech32CaseInsensitive(t *testing.T) { + data, _ := hex.DecodeString("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef") + encoded, _ := Bech32Encode("npub", data) + + // Test uppercase + upper := "NPUB1" + encoded[5:] + hrp, decoded, err := Bech32Decode(upper) + if err != nil { + t.Fatalf("Bech32Decode(uppercase) error = %v", err) + } + if hrp != "npub" { + t.Errorf("HRP = %s, want npub", hrp) + } + if !bytes.Equal(decoded, data) { + t.Error("Uppercase decode failed") + } +} diff --git a/internal/nostr/envelope.go b/internal/nostr/envelope.go new file mode 100644 index 0000000..d395efa --- /dev/null +++ b/internal/nostr/envelope.go @@ -0,0 +1,262 @@ +package nostr + +import ( + "encoding/json" + "fmt" +) + +// Envelope represents a Nostr protocol message. +type Envelope interface { + Label() string + MarshalJSON() ([]byte, error) +} + +// EventEnvelope wraps an event for the EVENT message. +// Used both client→relay (publish) and relay→client (subscription). +type EventEnvelope struct { + SubscriptionID string // Only for relay→client messages + Event *Event +} + +func (e EventEnvelope) Label() string { return "EVENT" } + +func (e EventEnvelope) MarshalJSON() ([]byte, error) { + if e.SubscriptionID != "" { + return json.Marshal([]interface{}{"EVENT", e.SubscriptionID, e.Event}) + } + return json.Marshal([]interface{}{"EVENT", e.Event}) +} + +// ReqEnvelope represents a REQ message (client→relay). +type ReqEnvelope struct { + SubscriptionID string + Filters []Filter +} + +func (e ReqEnvelope) Label() string { return "REQ" } + +func (e ReqEnvelope) MarshalJSON() ([]byte, error) { + arr := make([]interface{}, 2+len(e.Filters)) + arr[0] = "REQ" + arr[1] = e.SubscriptionID + for i, f := range e.Filters { + arr[2+i] = f + } + return json.Marshal(arr) +} + +// CloseEnvelope represents a CLOSE message (client→relay). +type CloseEnvelope struct { + SubscriptionID string +} + +func (e CloseEnvelope) Label() string { return "CLOSE" } + +func (e CloseEnvelope) MarshalJSON() ([]byte, error) { + return json.Marshal([]interface{}{"CLOSE", e.SubscriptionID}) +} + +// OKEnvelope represents an OK message (relay→client). +type OKEnvelope struct { + EventID string + OK bool + Message string +} + +func (e OKEnvelope) Label() string { return "OK" } + +func (e OKEnvelope) MarshalJSON() ([]byte, error) { + return json.Marshal([]interface{}{"OK", e.EventID, e.OK, e.Message}) +} + +// EOSEEnvelope represents an EOSE (End of Stored Events) message (relay→client). +type EOSEEnvelope struct { + SubscriptionID string +} + +func (e EOSEEnvelope) Label() string { return "EOSE" } + +func (e EOSEEnvelope) MarshalJSON() ([]byte, error) { + return json.Marshal([]interface{}{"EOSE", e.SubscriptionID}) +} + +// ClosedEnvelope represents a CLOSED message (relay→client). +type ClosedEnvelope struct { + SubscriptionID string + Message string +} + +func (e ClosedEnvelope) Label() string { return "CLOSED" } + +func (e ClosedEnvelope) MarshalJSON() ([]byte, error) { + return json.Marshal([]interface{}{"CLOSED", e.SubscriptionID, e.Message}) +} + +// NoticeEnvelope represents a NOTICE message (relay→client). +type NoticeEnvelope struct { + Message string +} + +func (e NoticeEnvelope) Label() string { return "NOTICE" } + +func (e NoticeEnvelope) MarshalJSON() ([]byte, error) { + return json.Marshal([]interface{}{"NOTICE", e.Message}) +} + +// ParseEnvelope parses a raw JSON message into the appropriate envelope type. +func ParseEnvelope(data []byte) (Envelope, error) { + var arr []json.RawMessage + if err := json.Unmarshal(data, &arr); err != nil { + return nil, fmt.Errorf("invalid envelope: %w", err) + } + + if len(arr) < 2 { + return nil, fmt.Errorf("envelope too short") + } + + var label string + if err := json.Unmarshal(arr[0], &label); err != nil { + return nil, fmt.Errorf("invalid envelope label: %w", err) + } + + switch label { + case "EVENT": + return parseEventEnvelope(arr) + case "REQ": + return parseReqEnvelope(arr) + case "CLOSE": + return parseCloseEnvelope(arr) + case "OK": + return parseOKEnvelope(arr) + case "EOSE": + return parseEOSEEnvelope(arr) + case "CLOSED": + return parseClosedEnvelope(arr) + case "NOTICE": + return parseNoticeEnvelope(arr) + default: + return nil, fmt.Errorf("unknown envelope type: %s", label) + } +} + +func parseEventEnvelope(arr []json.RawMessage) (*EventEnvelope, error) { + env := &EventEnvelope{} + + if len(arr) == 2 { + // Client→relay: ["EVENT", event] + var event Event + if err := json.Unmarshal(arr[1], &event); err != nil { + return nil, fmt.Errorf("invalid event: %w", err) + } + env.Event = &event + } else if len(arr) == 3 { + // Relay→client: ["EVENT", subscription_id, event] + if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { + return nil, fmt.Errorf("invalid subscription ID: %w", err) + } + var event Event + if err := json.Unmarshal(arr[2], &event); err != nil { + return nil, fmt.Errorf("invalid event: %w", err) + } + env.Event = &event + } else { + return nil, fmt.Errorf("invalid EVENT envelope length: %d", len(arr)) + } + + return env, nil +} + +func parseReqEnvelope(arr []json.RawMessage) (*ReqEnvelope, error) { + if len(arr) < 3 { + return nil, fmt.Errorf("REQ envelope must have at least 3 elements") + } + + env := &ReqEnvelope{} + if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { + return nil, fmt.Errorf("invalid subscription ID: %w", err) + } + + for i := 2; i < len(arr); i++ { + var filter Filter + if err := json.Unmarshal(arr[i], &filter); err != nil { + return nil, fmt.Errorf("invalid filter at index %d: %w", i-2, err) + } + env.Filters = append(env.Filters, filter) + } + + return env, nil +} + +func parseCloseEnvelope(arr []json.RawMessage) (*CloseEnvelope, error) { + if len(arr) != 2 { + return nil, fmt.Errorf("CLOSE envelope must have exactly 2 elements") + } + + env := &CloseEnvelope{} + if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { + return nil, fmt.Errorf("invalid subscription ID: %w", err) + } + + return env, nil +} + +func parseOKEnvelope(arr []json.RawMessage) (*OKEnvelope, error) { + if len(arr) != 4 { + return nil, fmt.Errorf("OK envelope must have exactly 4 elements") + } + + env := &OKEnvelope{} + if err := json.Unmarshal(arr[1], &env.EventID); err != nil { + return nil, fmt.Errorf("invalid event ID: %w", err) + } + if err := json.Unmarshal(arr[2], &env.OK); err != nil { + return nil, fmt.Errorf("invalid OK status: %w", err) + } + if err := json.Unmarshal(arr[3], &env.Message); err != nil { + return nil, fmt.Errorf("invalid message: %w", err) + } + + return env, nil +} + +func parseEOSEEnvelope(arr []json.RawMessage) (*EOSEEnvelope, error) { + if len(arr) != 2 { + return nil, fmt.Errorf("EOSE envelope must have exactly 2 elements") + } + + env := &EOSEEnvelope{} + if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { + return nil, fmt.Errorf("invalid subscription ID: %w", err) + } + + return env, nil +} + +func parseClosedEnvelope(arr []json.RawMessage) (*ClosedEnvelope, error) { + if len(arr) != 3 { + return nil, fmt.Errorf("CLOSED envelope must have exactly 3 elements") + } + + env := &ClosedEnvelope{} + if err := json.Unmarshal(arr[1], &env.SubscriptionID); err != nil { + return nil, fmt.Errorf("invalid subscription ID: %w", err) + } + if err := json.Unmarshal(arr[2], &env.Message); err != nil { + return nil, fmt.Errorf("invalid message: %w", err) + } + + return env, nil +} + +func parseNoticeEnvelope(arr []json.RawMessage) (*NoticeEnvelope, error) { + if len(arr) != 2 { + return nil, fmt.Errorf("NOTICE envelope must have exactly 2 elements") + } + + env := &NoticeEnvelope{} + if err := json.Unmarshal(arr[1], &env.Message); err != nil { + return nil, fmt.Errorf("invalid message: %w", err) + } + + return env, nil +} diff --git a/internal/nostr/envelope_test.go b/internal/nostr/envelope_test.go new file mode 100644 index 0000000..8f79ad5 --- /dev/null +++ b/internal/nostr/envelope_test.go @@ -0,0 +1,416 @@ +package nostr + +import ( + "encoding/json" + "testing" +) + +func TestEventEnvelopeMarshalJSON(t *testing.T) { + event := &Event{ + ID: "abc123", + PubKey: "def456", + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{}, + Content: "Hello", + Sig: "sig789", + } + + t.Run("client to relay", func(t *testing.T) { + env := EventEnvelope{Event: event} + data, err := env.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var arr []json.RawMessage + if err := json.Unmarshal(data, &arr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + if len(arr) != 2 { + t.Errorf("Array length = %d, want 2", len(arr)) + } + + var label string + json.Unmarshal(arr[0], &label) + if label != "EVENT" { + t.Errorf("Label = %s, want EVENT", label) + } + }) + + t.Run("relay to client", func(t *testing.T) { + env := EventEnvelope{SubscriptionID: "sub1", Event: event} + data, err := env.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var arr []json.RawMessage + if err := json.Unmarshal(data, &arr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + if len(arr) != 3 { + t.Errorf("Array length = %d, want 3", len(arr)) + } + }) +} + +func TestReqEnvelopeMarshalJSON(t *testing.T) { + env := ReqEnvelope{ + SubscriptionID: "sub1", + Filters: []Filter{ + {Kinds: []int{1}}, + {Authors: []string{"abc123"}}, + }, + } + + data, err := env.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var arr []json.RawMessage + if err := json.Unmarshal(data, &arr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + if len(arr) != 4 { // ["REQ", "sub1", filter1, filter2] + t.Errorf("Array length = %d, want 4", len(arr)) + } + + var label string + json.Unmarshal(arr[0], &label) + if label != "REQ" { + t.Errorf("Label = %s, want REQ", label) + } + + var subID string + json.Unmarshal(arr[1], &subID) + if subID != "sub1" { + t.Errorf("SubscriptionID = %s, want sub1", subID) + } +} + +func TestCloseEnvelopeMarshalJSON(t *testing.T) { + env := CloseEnvelope{SubscriptionID: "sub1"} + data, err := env.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var arr []interface{} + if err := json.Unmarshal(data, &arr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + if len(arr) != 2 { + t.Errorf("Array length = %d, want 2", len(arr)) + } + if arr[0] != "CLOSE" { + t.Errorf("Label = %v, want CLOSE", arr[0]) + } + if arr[1] != "sub1" { + t.Errorf("SubscriptionID = %v, want sub1", arr[1]) + } +} + +func TestOKEnvelopeMarshalJSON(t *testing.T) { + env := OKEnvelope{ + EventID: "event123", + OK: true, + Message: "accepted", + } + + data, err := env.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var arr []interface{} + if err := json.Unmarshal(data, &arr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + if len(arr) != 4 { + t.Errorf("Array length = %d, want 4", len(arr)) + } + if arr[0] != "OK" { + t.Errorf("Label = %v, want OK", arr[0]) + } + if arr[1] != "event123" { + t.Errorf("EventID = %v, want event123", arr[1]) + } + if arr[2] != true { + t.Errorf("OK = %v, want true", arr[2]) + } + if arr[3] != "accepted" { + t.Errorf("Message = %v, want accepted", arr[3]) + } +} + +func TestEOSEEnvelopeMarshalJSON(t *testing.T) { + env := EOSEEnvelope{SubscriptionID: "sub1"} + data, err := env.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var arr []interface{} + if err := json.Unmarshal(data, &arr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + if len(arr) != 2 { + t.Errorf("Array length = %d, want 2", len(arr)) + } + if arr[0] != "EOSE" { + t.Errorf("Label = %v, want EOSE", arr[0]) + } +} + +func TestClosedEnvelopeMarshalJSON(t *testing.T) { + env := ClosedEnvelope{ + SubscriptionID: "sub1", + Message: "rate limited", + } + + data, err := env.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var arr []interface{} + if err := json.Unmarshal(data, &arr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + if len(arr) != 3 { + t.Errorf("Array length = %d, want 3", len(arr)) + } + if arr[0] != "CLOSED" { + t.Errorf("Label = %v, want CLOSED", arr[0]) + } +} + +func TestNoticeEnvelopeMarshalJSON(t *testing.T) { + env := NoticeEnvelope{Message: "error: rate limited"} + data, err := env.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var arr []interface{} + if err := json.Unmarshal(data, &arr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + if len(arr) != 2 { + t.Errorf("Array length = %d, want 2", len(arr)) + } + if arr[0] != "NOTICE" { + t.Errorf("Label = %v, want NOTICE", arr[0]) + } +} + +func TestParseEnvelopeEvent(t *testing.T) { + t.Run("client to relay", func(t *testing.T) { + data := `["EVENT",{"id":"abc123","pubkey":"def456","created_at":1704067200,"kind":1,"tags":[],"content":"Hello","sig":"sig789"}]` + env, err := ParseEnvelope([]byte(data)) + if err != nil { + t.Fatalf("ParseEnvelope() error = %v", err) + } + + eventEnv, ok := env.(*EventEnvelope) + if !ok { + t.Fatalf("Expected *EventEnvelope, got %T", env) + } + + if eventEnv.SubscriptionID != "" { + t.Errorf("SubscriptionID = %s, want empty", eventEnv.SubscriptionID) + } + if eventEnv.Event.ID != "abc123" { + t.Errorf("Event.ID = %s, want abc123", eventEnv.Event.ID) + } + }) + + t.Run("relay to client", func(t *testing.T) { + data := `["EVENT","sub1",{"id":"abc123","pubkey":"def456","created_at":1704067200,"kind":1,"tags":[],"content":"Hello","sig":"sig789"}]` + env, err := ParseEnvelope([]byte(data)) + if err != nil { + t.Fatalf("ParseEnvelope() error = %v", err) + } + + eventEnv, ok := env.(*EventEnvelope) + if !ok { + t.Fatalf("Expected *EventEnvelope, got %T", env) + } + + if eventEnv.SubscriptionID != "sub1" { + t.Errorf("SubscriptionID = %s, want sub1", eventEnv.SubscriptionID) + } + }) +} + +func TestParseEnvelopeReq(t *testing.T) { + data := `["REQ","sub1",{"kinds":[1]},{"authors":["abc123"]}]` + env, err := ParseEnvelope([]byte(data)) + if err != nil { + t.Fatalf("ParseEnvelope() error = %v", err) + } + + reqEnv, ok := env.(*ReqEnvelope) + if !ok { + t.Fatalf("Expected *ReqEnvelope, got %T", env) + } + + if reqEnv.SubscriptionID != "sub1" { + t.Errorf("SubscriptionID = %s, want sub1", reqEnv.SubscriptionID) + } + if len(reqEnv.Filters) != 2 { + t.Errorf("Filters length = %d, want 2", len(reqEnv.Filters)) + } +} + +func TestParseEnvelopeClose(t *testing.T) { + data := `["CLOSE","sub1"]` + env, err := ParseEnvelope([]byte(data)) + if err != nil { + t.Fatalf("ParseEnvelope() error = %v", err) + } + + closeEnv, ok := env.(*CloseEnvelope) + if !ok { + t.Fatalf("Expected *CloseEnvelope, got %T", env) + } + + if closeEnv.SubscriptionID != "sub1" { + t.Errorf("SubscriptionID = %s, want sub1", closeEnv.SubscriptionID) + } +} + +func TestParseEnvelopeOK(t *testing.T) { + data := `["OK","event123",true,"accepted"]` + env, err := ParseEnvelope([]byte(data)) + if err != nil { + t.Fatalf("ParseEnvelope() error = %v", err) + } + + okEnv, ok := env.(*OKEnvelope) + if !ok { + t.Fatalf("Expected *OKEnvelope, got %T", env) + } + + if okEnv.EventID != "event123" { + t.Errorf("EventID = %s, want event123", okEnv.EventID) + } + if !okEnv.OK { + t.Error("OK = false, want true") + } + if okEnv.Message != "accepted" { + t.Errorf("Message = %s, want accepted", okEnv.Message) + } +} + +func TestParseEnvelopeEOSE(t *testing.T) { + data := `["EOSE","sub1"]` + env, err := ParseEnvelope([]byte(data)) + if err != nil { + t.Fatalf("ParseEnvelope() error = %v", err) + } + + eoseEnv, ok := env.(*EOSEEnvelope) + if !ok { + t.Fatalf("Expected *EOSEEnvelope, got %T", env) + } + + if eoseEnv.SubscriptionID != "sub1" { + t.Errorf("SubscriptionID = %s, want sub1", eoseEnv.SubscriptionID) + } +} + +func TestParseEnvelopeClosed(t *testing.T) { + data := `["CLOSED","sub1","rate limited"]` + env, err := ParseEnvelope([]byte(data)) + if err != nil { + t.Fatalf("ParseEnvelope() error = %v", err) + } + + closedEnv, ok := env.(*ClosedEnvelope) + if !ok { + t.Fatalf("Expected *ClosedEnvelope, got %T", env) + } + + if closedEnv.SubscriptionID != "sub1" { + t.Errorf("SubscriptionID = %s, want sub1", closedEnv.SubscriptionID) + } + if closedEnv.Message != "rate limited" { + t.Errorf("Message = %s, want rate limited", closedEnv.Message) + } +} + +func TestParseEnvelopeNotice(t *testing.T) { + data := `["NOTICE","error: rate limited"]` + env, err := ParseEnvelope([]byte(data)) + if err != nil { + t.Fatalf("ParseEnvelope() error = %v", err) + } + + noticeEnv, ok := env.(*NoticeEnvelope) + if !ok { + t.Fatalf("Expected *NoticeEnvelope, got %T", env) + } + + if noticeEnv.Message != "error: rate limited" { + t.Errorf("Message = %s, want 'error: rate limited'", noticeEnv.Message) + } +} + +func TestParseEnvelopeErrors(t *testing.T) { + tests := []struct { + name string + data string + }{ + {"invalid json", "not json"}, + {"not array", `{"type":"EVENT"}`}, + {"empty array", `[]`}, + {"single element", `["EVENT"]`}, + {"unknown type", `["UNKNOWN","data"]`}, + {"invalid event length", `["EVENT","a","b","c"]`}, + {"invalid ok length", `["OK","id",true]`}, + {"invalid eose length", `["EOSE"]`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseEnvelope([]byte(tt.data)) + if err == nil { + t.Error("ParseEnvelope() expected error, got nil") + } + }) + } +} + +func TestEnvelopeLabel(t *testing.T) { + tests := []struct { + env Envelope + label string + }{ + {EventEnvelope{}, "EVENT"}, + {ReqEnvelope{}, "REQ"}, + {CloseEnvelope{}, "CLOSE"}, + {OKEnvelope{}, "OK"}, + {EOSEEnvelope{}, "EOSE"}, + {ClosedEnvelope{}, "CLOSED"}, + {NoticeEnvelope{}, "NOTICE"}, + } + + for _, tt := range tests { + t.Run(tt.label, func(t *testing.T) { + if got := tt.env.Label(); got != tt.label { + t.Errorf("Label() = %s, want %s", got, tt.label) + } + }) + } +} diff --git a/internal/nostr/event.go b/internal/nostr/event.go new file mode 100644 index 0000000..a8156bb --- /dev/null +++ b/internal/nostr/event.go @@ -0,0 +1,72 @@ +package nostr + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" +) + +// Event represents a Nostr event as defined in NIP-01. +type Event struct { + ID string `json:"id"` + PubKey string `json:"pubkey"` + CreatedAt int64 `json:"created_at"` + Kind int `json:"kind"` + Tags Tags `json:"tags"` + Content string `json:"content"` + Sig string `json:"sig"` +} + +// Serialize returns the canonical JSON serialization of the event for ID computation. +// Format: [0, "pubkey", created_at, kind, tags, "content"] +func (e *Event) Serialize() []byte { + // Use json.Marshal for proper escaping of content and tags + arr := []interface{}{ + 0, + e.PubKey, + e.CreatedAt, + e.Kind, + e.Tags, + e.Content, + } + data, _ := json.Marshal(arr) + return data +} + +// ComputeID calculates the SHA256 hash of the serialized event. +// Returns the 64-character hex-encoded ID. +func (e *Event) ComputeID() string { + serialized := e.Serialize() + hash := sha256.Sum256(serialized) + return hex.EncodeToString(hash[:]) +} + +// SetID computes and sets the event ID. +func (e *Event) SetID() { + e.ID = e.ComputeID() +} + +// CheckID verifies that the event ID matches the computed ID. +func (e *Event) CheckID() bool { + return e.ID == e.ComputeID() +} + +// MarshalJSON implements json.Marshaler with empty tags as [] instead of null. +func (e Event) MarshalJSON() ([]byte, error) { + type eventAlias Event + ea := eventAlias(e) + if ea.Tags == nil { + ea.Tags = Tags{} + } + return json.Marshal(ea) +} + +// String returns a JSON representation of the event for debugging. +func (e *Event) String() string { + data, err := json.MarshalIndent(e, "", " ") + if err != nil { + return fmt.Sprintf("", err) + } + return string(data) +} diff --git a/internal/nostr/event_test.go b/internal/nostr/event_test.go new file mode 100644 index 0000000..eff4103 --- /dev/null +++ b/internal/nostr/event_test.go @@ -0,0 +1,194 @@ +package nostr + +import ( + "encoding/json" + "testing" +) + +func TestEventSerialize(t *testing.T) { + event := &Event{ + PubKey: "7e7e9c42a91bfef19fa929e5fda1b72e0ebc1a4c1141673e2794234d86addf4e", + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{{"e", "abc123"}, {"p", "def456"}}, + Content: "Hello, Nostr!", + } + + serialized := event.Serialize() + + // Parse the JSON to verify structure + var arr []interface{} + if err := json.Unmarshal(serialized, &arr); err != nil { + t.Fatalf("Serialize() produced invalid JSON: %v", err) + } + + if len(arr) != 6 { + t.Fatalf("Serialized array has %d elements, want 6", len(arr)) + } + + // Check each element + if arr[0].(float64) != 0 { + t.Errorf("arr[0] = %v, want 0", arr[0]) + } + if arr[1].(string) != event.PubKey { + t.Errorf("arr[1] = %v, want %s", arr[1], event.PubKey) + } + if int64(arr[2].(float64)) != event.CreatedAt { + t.Errorf("arr[2] = %v, want %d", arr[2], event.CreatedAt) + } + if int(arr[3].(float64)) != event.Kind { + t.Errorf("arr[3] = %v, want %d", arr[3], event.Kind) + } + if arr[5].(string) != event.Content { + t.Errorf("arr[5] = %v, want %s", arr[5], event.Content) + } +} + +func TestEventComputeID(t *testing.T) { + // Test with a known event (you can verify with other implementations) + event := &Event{ + PubKey: "7e7e9c42a91bfef19fa929e5fda1b72e0ebc1a4c1141673e2794234d86addf4e", + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{}, + Content: "Hello, Nostr!", + } + + id := event.ComputeID() + + // ID should be 64 hex characters + if len(id) != 64 { + t.Errorf("ComputeID() returned ID of length %d, want 64", len(id)) + } + + // Verify it's valid hex + for _, c := range id { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("ComputeID() returned invalid hex character: %c", c) + } + } + + // Verify consistency + id2 := event.ComputeID() + if id != id2 { + t.Errorf("ComputeID() is not consistent: %s != %s", id, id2) + } +} + +func TestEventSetID(t *testing.T) { + event := &Event{ + PubKey: "7e7e9c42a91bfef19fa929e5fda1b72e0ebc1a4c1141673e2794234d86addf4e", + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{}, + Content: "Test", + } + + event.SetID() + if event.ID == "" { + t.Error("SetID() did not set ID") + } + if !event.CheckID() { + t.Error("CheckID() returned false after SetID()") + } +} + +func TestEventCheckID(t *testing.T) { + event := &Event{ + PubKey: "7e7e9c42a91bfef19fa929e5fda1b72e0ebc1a4c1141673e2794234d86addf4e", + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{}, + Content: "Test", + } + + event.SetID() + + if !event.CheckID() { + t.Error("CheckID() returned false for valid ID") + } + + // Corrupt the ID + event.ID = "0000000000000000000000000000000000000000000000000000000000000000" + if event.CheckID() { + t.Error("CheckID() returned true for invalid ID") + } +} + +func TestEventMarshalJSON(t *testing.T) { + event := Event{ + ID: "abc123", + PubKey: "def456", + CreatedAt: 1704067200, + Kind: 1, + Tags: nil, // nil tags + Content: "Test", + Sig: "sig789", + } + + data, err := json.Marshal(event) + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + // Verify tags is [] not null + var m map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + tags, ok := m["tags"] + if !ok { + t.Error("tags field missing from JSON") + } + if tags == nil { + t.Error("tags is null, want []") + } + if arr, ok := tags.([]interface{}); !ok || len(arr) != 0 { + t.Errorf("tags = %v, want []", tags) + } +} + +func TestEventJSONRoundTrip(t *testing.T) { + original := Event{ + ID: "abc123def456", + PubKey: "pubkey123", + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{{"e", "event1"}, {"p", "pubkey1", "relay"}}, + Content: "Hello with \"quotes\" and \n newlines", + Sig: "signature123", + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var decoded Event + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if decoded.ID != original.ID { + t.Errorf("ID mismatch: %s != %s", decoded.ID, original.ID) + } + if decoded.PubKey != original.PubKey { + t.Errorf("PubKey mismatch: %s != %s", decoded.PubKey, original.PubKey) + } + if decoded.CreatedAt != original.CreatedAt { + t.Errorf("CreatedAt mismatch: %d != %d", decoded.CreatedAt, original.CreatedAt) + } + if decoded.Kind != original.Kind { + t.Errorf("Kind mismatch: %d != %d", decoded.Kind, original.Kind) + } + if decoded.Content != original.Content { + t.Errorf("Content mismatch: %s != %s", decoded.Content, original.Content) + } + if decoded.Sig != original.Sig { + t.Errorf("Sig mismatch: %s != %s", decoded.Sig, original.Sig) + } + if len(decoded.Tags) != len(original.Tags) { + t.Errorf("Tags length mismatch: %d != %d", len(decoded.Tags), len(original.Tags)) + } +} diff --git a/internal/nostr/example_test.go b/internal/nostr/example_test.go new file mode 100644 index 0000000..80acd21 --- /dev/null +++ b/internal/nostr/example_test.go @@ -0,0 +1,85 @@ +package nostr_test + +import ( + "context" + "fmt" + "time" + + "northwest.io/nostr-grpc/internal/nostr" +) + +// Example_basic demonstrates basic usage of the nostr library. +func Example_basic() { + // Generate a new key pair + key, err := nostr.GenerateKey() + if err != nil { + fmt.Printf("Failed to generate key: %v\n", err) + return + } + + fmt.Printf("Public key (hex): %s...\n", key.Public()[:16]) + fmt.Printf("Public key (npub): %s...\n", key.Npub()[:20]) + + // Create an event + event := &nostr.Event{ + CreatedAt: time.Now().Unix(), + Kind: nostr.KindTextNote, + Tags: nostr.Tags{{"t", "test"}}, + Content: "Hello from nostr-go!", + } + + // Sign the event + if err := key.Sign(event); err != nil { + fmt.Printf("Failed to sign event: %v\n", err) + return + } + + // Verify the signature + if event.Verify() { + fmt.Println("Event signature verified!") + } + + // Create a filter to match our event + filter := nostr.Filter{ + Kinds: []int{nostr.KindTextNote}, + Authors: []string{key.Public()[:8]}, // Prefix matching + } + + if filter.Matches(event) { + fmt.Println("Filter matches the event!") + } +} + +// ExampleRelay demonstrates connecting to a relay (requires network). +// This is a documentation example - run with: go test -v -run ExampleRelay +func ExampleRelay() { + ctx := context.Background() + + // Connect to a public relay + relay, err := nostr.Connect(ctx, "wss://relay.damus.io") + if err != nil { + fmt.Printf("Failed to connect: %v\n", err) + return + } + defer relay.Close() + + fmt.Println("Connected to relay!") + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + // Fetch recent text notes (closes on EOSE) + since := time.Now().Add(-1 * time.Hour).Unix() + sub := relay.Fetch(ctx, nostr.Filter{ + Kinds: []int{nostr.KindTextNote}, + Since: &since, + Limit: 5, + }) + + eventCount := 0 + for event := range sub.Events { + eventCount++ + fmt.Printf("Received event from %s...\n", event.PubKey[:8]) + } + fmt.Printf("Received %d events\n", eventCount) +} diff --git a/internal/nostr/filter.go b/internal/nostr/filter.go new file mode 100644 index 0000000..dde04a5 --- /dev/null +++ b/internal/nostr/filter.go @@ -0,0 +1,224 @@ +package nostr + +import ( + "encoding/json" + "strings" +) + +// Filter represents a subscription filter as defined in NIP-01. +type Filter struct { + IDs []string `json:"ids,omitempty"` + Kinds []int `json:"kinds,omitempty"` + Authors []string `json:"authors,omitempty"` + Tags map[string][]string `json:"-"` // Custom marshaling for #e, #p, etc. + Since *int64 `json:"since,omitempty"` + Until *int64 `json:"until,omitempty"` + Limit int `json:"limit,omitempty"` +} + +// MarshalJSON implements json.Marshaler for Filter. +// Converts Tags map to #e, #p format. +func (f Filter) MarshalJSON() ([]byte, error) { + // Create a map for custom marshaling + m := make(map[string]interface{}) + + if len(f.IDs) > 0 { + m["ids"] = f.IDs + } + if len(f.Kinds) > 0 { + m["kinds"] = f.Kinds + } + if len(f.Authors) > 0 { + m["authors"] = f.Authors + } + if f.Since != nil { + m["since"] = *f.Since + } + if f.Until != nil { + m["until"] = *f.Until + } + if f.Limit > 0 { + m["limit"] = f.Limit + } + + // Add tag filters with # prefix + for key, values := range f.Tags { + if len(values) > 0 { + m["#"+key] = values + } + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements json.Unmarshaler for Filter. +// Extracts #e, #p format into Tags map. +func (f *Filter) UnmarshalJSON(data []byte) error { + // First unmarshal into a raw map + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Extract known fields + if v, ok := raw["ids"]; ok { + if err := json.Unmarshal(v, &f.IDs); err != nil { + return err + } + } + if v, ok := raw["kinds"]; ok { + if err := json.Unmarshal(v, &f.Kinds); err != nil { + return err + } + } + if v, ok := raw["authors"]; ok { + if err := json.Unmarshal(v, &f.Authors); err != nil { + return err + } + } + if v, ok := raw["since"]; ok { + var since int64 + if err := json.Unmarshal(v, &since); err != nil { + return err + } + f.Since = &since + } + if v, ok := raw["until"]; ok { + var until int64 + if err := json.Unmarshal(v, &until); err != nil { + return err + } + f.Until = &until + } + if v, ok := raw["limit"]; ok { + if err := json.Unmarshal(v, &f.Limit); err != nil { + return err + } + } + + // Extract tag filters (fields starting with #) + f.Tags = make(map[string][]string) + for key, value := range raw { + if strings.HasPrefix(key, "#") { + tagKey := strings.TrimPrefix(key, "#") + var values []string + if err := json.Unmarshal(value, &values); err != nil { + return err + } + f.Tags[tagKey] = values + } + } + + return nil +} + +// Matches checks if an event matches this filter. +func (f *Filter) Matches(event *Event) bool { + // Check IDs (prefix match) + if len(f.IDs) > 0 { + found := false + for _, id := range f.IDs { + if strings.HasPrefix(event.ID, id) { + found = true + break + } + } + if !found { + return false + } + } + + // Check authors (prefix match) + if len(f.Authors) > 0 { + found := false + for _, author := range f.Authors { + if strings.HasPrefix(event.PubKey, author) { + found = true + break + } + } + if !found { + return false + } + } + + // Check kinds + if len(f.Kinds) > 0 { + found := false + for _, kind := range f.Kinds { + if event.Kind == kind { + found = true + break + } + } + if !found { + return false + } + } + + // Check since + if f.Since != nil && event.CreatedAt < *f.Since { + return false + } + + // Check until + if f.Until != nil && event.CreatedAt > *f.Until { + return false + } + + // Check tag filters + for tagKey, values := range f.Tags { + if len(values) == 0 { + continue + } + found := false + for _, val := range values { + if event.Tags.ContainsValue(tagKey, val) { + found = true + break + } + } + if !found { + return false + } + } + + return true +} + +// Clone creates a deep copy of the filter. +func (f *Filter) Clone() *Filter { + clone := &Filter{ + Limit: f.Limit, + } + + if f.IDs != nil { + clone.IDs = make([]string, len(f.IDs)) + copy(clone.IDs, f.IDs) + } + if f.Kinds != nil { + clone.Kinds = make([]int, len(f.Kinds)) + copy(clone.Kinds, f.Kinds) + } + if f.Authors != nil { + clone.Authors = make([]string, len(f.Authors)) + copy(clone.Authors, f.Authors) + } + if f.Since != nil { + since := *f.Since + clone.Since = &since + } + if f.Until != nil { + until := *f.Until + clone.Until = &until + } + if f.Tags != nil { + clone.Tags = make(map[string][]string) + for k, v := range f.Tags { + clone.Tags[k] = make([]string, len(v)) + copy(clone.Tags[k], v) + } + } + + return clone +} diff --git a/internal/nostr/filter_test.go b/internal/nostr/filter_test.go new file mode 100644 index 0000000..ebe2b1d --- /dev/null +++ b/internal/nostr/filter_test.go @@ -0,0 +1,415 @@ +package nostr + +import ( + "encoding/json" + "testing" +) + +func TestFilterMarshalJSON(t *testing.T) { + since := int64(1704067200) + until := int64(1704153600) + + filter := Filter{ + IDs: []string{"abc123"}, + Kinds: []int{1, 7}, + Authors: []string{"def456"}, + Tags: map[string][]string{ + "e": {"event1", "event2"}, + "p": {"pubkey1"}, + }, + Since: &since, + Until: &until, + Limit: 100, + } + + data, err := filter.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + // Parse and check structure + var m map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + // Check regular fields + if _, ok := m["ids"]; !ok { + t.Error("ids field missing") + } + if _, ok := m["kinds"]; !ok { + t.Error("kinds field missing") + } + if _, ok := m["authors"]; !ok { + t.Error("authors field missing") + } + if _, ok := m["since"]; !ok { + t.Error("since field missing") + } + if _, ok := m["until"]; !ok { + t.Error("until field missing") + } + if _, ok := m["limit"]; !ok { + t.Error("limit field missing") + } + + // Check tag filters with # prefix + if _, ok := m["#e"]; !ok { + t.Error("#e field missing") + } + if _, ok := m["#p"]; !ok { + t.Error("#p field missing") + } +} + +func TestFilterMarshalJSONOmitsEmpty(t *testing.T) { + filter := Filter{ + Kinds: []int{1}, + } + + data, err := filter.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var m map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if _, ok := m["ids"]; ok { + t.Error("empty ids should be omitted") + } + if _, ok := m["authors"]; ok { + t.Error("empty authors should be omitted") + } + if _, ok := m["since"]; ok { + t.Error("nil since should be omitted") + } + if _, ok := m["until"]; ok { + t.Error("nil until should be omitted") + } + if _, ok := m["limit"]; ok { + t.Error("zero limit should be omitted") + } +} + +func TestFilterUnmarshalJSON(t *testing.T) { + jsonData := `{ + "ids": ["abc123"], + "kinds": [1, 7], + "authors": ["def456"], + "#e": ["event1", "event2"], + "#p": ["pubkey1"], + "since": 1704067200, + "until": 1704153600, + "limit": 100 + }` + + var filter Filter + if err := json.Unmarshal([]byte(jsonData), &filter); err != nil { + t.Fatalf("UnmarshalJSON() error = %v", err) + } + + if len(filter.IDs) != 1 || filter.IDs[0] != "abc123" { + t.Errorf("IDs = %v, want [abc123]", filter.IDs) + } + if len(filter.Kinds) != 2 { + t.Errorf("Kinds length = %d, want 2", len(filter.Kinds)) + } + if len(filter.Authors) != 1 || filter.Authors[0] != "def456" { + t.Errorf("Authors = %v, want [def456]", filter.Authors) + } + if filter.Since == nil || *filter.Since != 1704067200 { + t.Errorf("Since = %v, want 1704067200", filter.Since) + } + if filter.Until == nil || *filter.Until != 1704153600 { + t.Errorf("Until = %v, want 1704153600", filter.Until) + } + if filter.Limit != 100 { + t.Errorf("Limit = %d, want 100", filter.Limit) + } + + // Check tag filters + if len(filter.Tags["e"]) != 2 { + t.Errorf("Tags[e] length = %d, want 2", len(filter.Tags["e"])) + } + if len(filter.Tags["p"]) != 1 { + t.Errorf("Tags[p] length = %d, want 1", len(filter.Tags["p"])) + } +} + +func TestFilterMatchesIDs(t *testing.T) { + filter := Filter{ + IDs: []string{"abc", "def456"}, + } + + tests := []struct { + id string + want bool + }{ + {"abc123", true}, // matches prefix "abc" + {"abcdef", true}, // matches prefix "abc" + {"def456", true}, // exact match + {"def456xyz", true}, // matches prefix "def456" + {"xyz789", false}, // no match + {"ab", false}, // "ab" doesn't start with "abc" + } + + for _, tt := range tests { + event := &Event{ID: tt.id} + if got := filter.Matches(event); got != tt.want { + t.Errorf("Matches() with ID %s = %v, want %v", tt.id, got, tt.want) + } + } +} + +func TestFilterMatchesAuthors(t *testing.T) { + filter := Filter{ + Authors: []string{"pubkey1", "pubkey2"}, + } + + tests := []struct { + pubkey string + want bool + }{ + {"pubkey1", true}, + {"pubkey1abc", true}, // Prefix match + {"pubkey2", true}, + {"pubkey3", false}, + } + + for _, tt := range tests { + event := &Event{PubKey: tt.pubkey} + if got := filter.Matches(event); got != tt.want { + t.Errorf("Matches() with PubKey %s = %v, want %v", tt.pubkey, got, tt.want) + } + } +} + +func TestFilterMatchesKinds(t *testing.T) { + filter := Filter{ + Kinds: []int{1, 7}, + } + + tests := []struct { + kind int + want bool + }{ + {1, true}, + {7, true}, + {0, false}, + {4, false}, + } + + for _, tt := range tests { + event := &Event{Kind: tt.kind} + if got := filter.Matches(event); got != tt.want { + t.Errorf("Matches() with Kind %d = %v, want %v", tt.kind, got, tt.want) + } + } +} + +func TestFilterMatchesSince(t *testing.T) { + since := int64(1704067200) + filter := Filter{ + Since: &since, + } + + tests := []struct { + createdAt int64 + want bool + }{ + {1704067200, true}, // Equal + {1704067201, true}, // After + {1704067199, false}, // Before + } + + for _, tt := range tests { + event := &Event{CreatedAt: tt.createdAt} + if got := filter.Matches(event); got != tt.want { + t.Errorf("Matches() with CreatedAt %d = %v, want %v", tt.createdAt, got, tt.want) + } + } +} + +func TestFilterMatchesUntil(t *testing.T) { + until := int64(1704067200) + filter := Filter{ + Until: &until, + } + + tests := []struct { + createdAt int64 + want bool + }{ + {1704067200, true}, // Equal + {1704067199, true}, // Before + {1704067201, false}, // After + } + + for _, tt := range tests { + event := &Event{CreatedAt: tt.createdAt} + if got := filter.Matches(event); got != tt.want { + t.Errorf("Matches() with CreatedAt %d = %v, want %v", tt.createdAt, got, tt.want) + } + } +} + +func TestFilterMatchesTags(t *testing.T) { + filter := Filter{ + Tags: map[string][]string{ + "e": {"event1"}, + "p": {"pubkey1", "pubkey2"}, + }, + } + + tests := []struct { + name string + tags Tags + want bool + }{ + { + name: "matches all", + tags: Tags{{"e", "event1"}, {"p", "pubkey1"}}, + want: true, + }, + { + name: "matches with different p", + tags: Tags{{"e", "event1"}, {"p", "pubkey2"}}, + want: true, + }, + { + name: "missing e tag", + tags: Tags{{"p", "pubkey1"}}, + want: false, + }, + { + name: "wrong e value", + tags: Tags{{"e", "event2"}, {"p", "pubkey1"}}, + want: false, + }, + { + name: "extra tags ok", + tags: Tags{{"e", "event1"}, {"p", "pubkey1"}, {"t", "test"}}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := &Event{Tags: tt.tags} + if got := filter.Matches(event); got != tt.want { + t.Errorf("Matches() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFilterMatchesEmpty(t *testing.T) { + // Empty filter matches everything + filter := Filter{} + event := &Event{ + ID: "abc123", + PubKey: "pubkey1", + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{{"e", "event1"}}, + Content: "test", + } + + if !filter.Matches(event) { + t.Error("Empty filter should match all events") + } +} + +func TestFilterClone(t *testing.T) { + since := int64(1704067200) + until := int64(1704153600) + + original := &Filter{ + IDs: []string{"id1", "id2"}, + Kinds: []int{1, 7}, + Authors: []string{"author1"}, + Tags: map[string][]string{ + "e": {"event1"}, + }, + Since: &since, + Until: &until, + Limit: 100, + } + + clone := original.Clone() + + // Modify original + original.IDs[0] = "modified" + original.Kinds[0] = 999 + original.Authors[0] = "modified" + original.Tags["e"][0] = "modified" + *original.Since = 0 + *original.Until = 0 + original.Limit = 0 + + // Clone should be unchanged + if clone.IDs[0] != "id1" { + t.Error("Clone IDs was modified") + } + if clone.Kinds[0] != 1 { + t.Error("Clone Kinds was modified") + } + if clone.Authors[0] != "author1" { + t.Error("Clone Authors was modified") + } + if clone.Tags["e"][0] != "event1" { + t.Error("Clone Tags was modified") + } + if *clone.Since != 1704067200 { + t.Error("Clone Since was modified") + } + if *clone.Until != 1704153600 { + t.Error("Clone Until was modified") + } + if clone.Limit != 100 { + t.Error("Clone Limit was modified") + } +} + +func TestFilterJSONRoundTrip(t *testing.T) { + since := int64(1704067200) + original := Filter{ + IDs: []string{"abc123"}, + Kinds: []int{1}, + Authors: []string{"def456"}, + Tags: map[string][]string{ + "e": {"event1"}, + }, + Since: &since, + Limit: 50, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var decoded Filter + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if len(decoded.IDs) != 1 || decoded.IDs[0] != "abc123" { + t.Errorf("IDs mismatch") + } + if len(decoded.Kinds) != 1 || decoded.Kinds[0] != 1 { + t.Errorf("Kinds mismatch") + } + if len(decoded.Tags["e"]) != 1 || decoded.Tags["e"][0] != "event1" { + t.Errorf("Tags mismatch") + } + if decoded.Since == nil || *decoded.Since != since { + t.Errorf("Since mismatch") + } + if decoded.Limit != 50 { + t.Errorf("Limit mismatch") + } +} diff --git a/internal/nostr/keys.go b/internal/nostr/keys.go new file mode 100644 index 0000000..3a3fb9c --- /dev/null +++ b/internal/nostr/keys.go @@ -0,0 +1,217 @@ +package nostr + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "strings" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr" +) + +// Key represents a Nostr key, which may be a full private key or public-only. +// Use GenerateKey or ParseKey for private keys, ParsePublicKey for public-only. +type Key struct { + priv *btcec.PrivateKey // nil for public-only keys + pub *btcec.PublicKey // always set +} + +// GenerateKey generates a new random private key. +func GenerateKey() (*Key, error) { + var keyBytes [32]byte + if _, err := rand.Read(keyBytes[:]); err != nil { + return nil, fmt.Errorf("failed to generate random bytes: %w", err) + } + + priv, _ := btcec.PrivKeyFromBytes(keyBytes[:]) + return &Key{ + priv: priv, + pub: priv.PubKey(), + }, nil +} + +// ParseKey parses a private key from hex or nsec (bech32) format. +func ParseKey(s string) (*Key, error) { + var privBytes []byte + + if strings.HasPrefix(s, "nsec1") { + hrp, data, err := Bech32Decode(s) + if err != nil { + return nil, fmt.Errorf("invalid nsec: %w", err) + } + if hrp != "nsec" { + return nil, fmt.Errorf("invalid prefix: expected nsec, got %s", hrp) + } + if len(data) != 32 { + return nil, fmt.Errorf("invalid nsec data length: %d", len(data)) + } + privBytes = data + } else { + var err error + privBytes, err = hex.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("invalid hex: %w", err) + } + } + + if len(privBytes) != 32 { + return nil, fmt.Errorf("private key must be 32 bytes, got %d", len(privBytes)) + } + + priv, _ := btcec.PrivKeyFromBytes(privBytes) + return &Key{ + priv: priv, + pub: priv.PubKey(), + }, nil +} + +// ParsePublicKey parses a public key from hex or npub (bech32) format. +// The returned Key can only verify, not sign. +func ParsePublicKey(s string) (*Key, error) { + var pubBytes []byte + + if strings.HasPrefix(s, "npub1") { + hrp, data, err := Bech32Decode(s) + if err != nil { + return nil, fmt.Errorf("invalid npub: %w", err) + } + if hrp != "npub" { + return nil, fmt.Errorf("invalid prefix: expected npub, got %s", hrp) + } + if len(data) != 32 { + return nil, fmt.Errorf("invalid npub data length: %d", len(data)) + } + pubBytes = data + } else { + var err error + pubBytes, err = hex.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("invalid hex: %w", err) + } + } + + if len(pubBytes) != 32 { + return nil, fmt.Errorf("public key must be 32 bytes, got %d", len(pubBytes)) + } + + pub, err := schnorr.ParsePubKey(pubBytes) + if err != nil { + return nil, fmt.Errorf("invalid public key: %w", err) + } + + return &Key{ + priv: nil, + pub: pub, + }, nil +} + +// CanSign returns true if this key can sign events (has private key). +func (k *Key) CanSign() bool { + return k.priv != nil +} + +// Public returns the public key as a 64-character hex string. +func (k *Key) Public() string { + return hex.EncodeToString(schnorr.SerializePubKey(k.pub)) +} + +// Private returns the private key as a 64-character hex string. +// Returns empty string if this is a public-only key. +func (k *Key) Private() string { + if k.priv == nil { + return "" + } + return hex.EncodeToString(k.priv.Serialize()) +} + +// Npub returns the public key in bech32 npub format. +func (k *Key) Npub() string { + pubBytes := schnorr.SerializePubKey(k.pub) + npub, _ := Bech32Encode("npub", pubBytes) + return npub +} + +// Nsec returns the private key in bech32 nsec format. +// Returns empty string if this is a public-only key. +func (k *Key) Nsec() string { + if k.priv == nil { + return "" + } + nsec, _ := Bech32Encode("nsec", k.priv.Serialize()) + return nsec +} + +// Sign signs the event with this key. +// Sets the PubKey, ID, and Sig fields on the event. +// Returns an error if this is a public-only key. +func (k *Key) Sign(event *Event) error { + if k.priv == nil { + return fmt.Errorf("cannot sign: public-only key") + } + + // Set public key + event.PubKey = k.Public() + + if event.CreatedAt == 0 { + event.CreatedAt = time.Now().Unix() + } + + // Compute ID + event.SetID() + + // Hash the ID for signing + idBytes, err := hex.DecodeString(event.ID) + if err != nil { + return fmt.Errorf("failed to decode event ID: %w", err) + } + + // Sign with Schnorr + sig, err := schnorr.Sign(k.priv, idBytes) + if err != nil { + return fmt.Errorf("failed to sign event: %w", err) + } + + event.Sig = hex.EncodeToString(sig.Serialize()) + return nil +} + +// Verify verifies the event signature. +// Returns true if the signature is valid, false otherwise. +func (e *Event) Verify() bool { + // Verify ID first + if !e.CheckID() { + return false + } + + // Decode public key + pubKeyBytes, err := hex.DecodeString(e.PubKey) + if err != nil || len(pubKeyBytes) != 32 { + return false + } + + pubKey, err := schnorr.ParsePubKey(pubKeyBytes) + if err != nil { + return false + } + + // Decode signature + sigBytes, err := hex.DecodeString(e.Sig) + if err != nil { + return false + } + + sig, err := schnorr.ParseSignature(sigBytes) + if err != nil { + return false + } + + // Decode ID (message hash) + idBytes, err := hex.DecodeString(e.ID) + if err != nil { + return false + } + + return sig.Verify(idBytes, pubKey) +} diff --git a/internal/nostr/keys_test.go b/internal/nostr/keys_test.go new file mode 100644 index 0000000..6c3dd3d --- /dev/null +++ b/internal/nostr/keys_test.go @@ -0,0 +1,333 @@ +package nostr + +import ( + "encoding/hex" + "strings" + "testing" +) + +func TestGenerateKey(t *testing.T) { + key1, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + if !key1.CanSign() { + t.Error("Generated key should be able to sign") + } + + // Private key should be 64 hex characters + if len(key1.Private()) != 64 { + t.Errorf("Private() length = %d, want 64", len(key1.Private())) + } + + // Public key should be 64 hex characters + if len(key1.Public()) != 64 { + t.Errorf("Public() length = %d, want 64", len(key1.Public())) + } + + // Should be valid hex + if _, err := hex.DecodeString(key1.Private()); err != nil { + t.Errorf("Private() is not valid hex: %v", err) + } + if _, err := hex.DecodeString(key1.Public()); err != nil { + t.Errorf("Public() is not valid hex: %v", err) + } + + // Keys should be unique + key2, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() second call error = %v", err) + } + if key1.Private() == key2.Private() { + t.Error("GenerateKey() returned same private key twice") + } +} + +func TestKeyNpubNsec(t *testing.T) { + key, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + npub := key.Npub() + nsec := key.Nsec() + + // Check prefixes + if !strings.HasPrefix(npub, "npub1") { + t.Errorf("Npub() = %s, want prefix 'npub1'", npub) + } + if !strings.HasPrefix(nsec, "nsec1") { + t.Errorf("Nsec() = %s, want prefix 'nsec1'", nsec) + } + + // Should be able to parse them back + keyFromNsec, err := ParseKey(nsec) + if err != nil { + t.Fatalf("ParseKey(nsec) error = %v", err) + } + if keyFromNsec.Private() != key.Private() { + t.Error("ParseKey(nsec) did not restore original private key") + } + + keyFromNpub, err := ParsePublicKey(npub) + if err != nil { + t.Fatalf("ParsePublicKey(npub) error = %v", err) + } + if keyFromNpub.Public() != key.Public() { + t.Error("ParsePublicKey(npub) did not restore original public key") + } +} + +func TestParseKey(t *testing.T) { + // Known test vector + hexKey := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + key, err := ParseKey(hexKey) + if err != nil { + t.Fatalf("ParseKey(hex) error = %v", err) + } + + if !key.CanSign() { + t.Error("ParseKey should return key that can sign") + } + + if key.Private() != hexKey { + t.Errorf("Private() = %s, want %s", key.Private(), hexKey) + } + + // Parse the nsec back + nsec := key.Nsec() + key2, err := ParseKey(nsec) + if err != nil { + t.Fatalf("ParseKey(nsec) error = %v", err) + } + if key2.Private() != hexKey { + t.Error("Round-trip through nsec failed") + } +} + +func TestParseKeyErrors(t *testing.T) { + tests := []struct { + name string + key string + }{ + {"invalid hex", "not-hex"}, + {"too short", "0123456789abcdef"}, + {"too long", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef00"}, + {"invalid nsec", "nsec1invalid"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseKey(tt.key) + if err == nil { + t.Error("ParseKey() expected error, got nil") + } + }) + } +} + +func TestParsePublicKey(t *testing.T) { + // Generate a key and extract public + fullKey, _ := GenerateKey() + pubHex := fullKey.Public() + + // Parse public key from hex + key, err := ParsePublicKey(pubHex) + if err != nil { + t.Fatalf("ParsePublicKey(hex) error = %v", err) + } + + if key.CanSign() { + t.Error("ParsePublicKey should return key that cannot sign") + } + + if key.Public() != pubHex { + t.Errorf("Public() = %s, want %s", key.Public(), pubHex) + } + + if key.Private() != "" { + t.Error("Private() should return empty string for public-only key") + } + + if key.Nsec() != "" { + t.Error("Nsec() should return empty string for public-only key") + } + + // Parse from npub + npub := fullKey.Npub() + key2, err := ParsePublicKey(npub) + if err != nil { + t.Fatalf("ParsePublicKey(npub) error = %v", err) + } + if key2.Public() != pubHex { + t.Error("ParsePublicKey(npub) did not restore correct public key") + } +} + +func TestParsePublicKeyErrors(t *testing.T) { + tests := []struct { + name string + key string + }{ + {"invalid hex", "not-hex"}, + {"too short", "0123456789abcdef"}, + {"invalid npub", "npub1invalid"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParsePublicKey(tt.key) + if err == nil { + t.Error("ParsePublicKey() expected error, got nil") + } + }) + } +} + +func TestKeySign(t *testing.T) { + key, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + event := &Event{ + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{}, + Content: "Test message", + } + + if err := key.Sign(event); err != nil { + t.Fatalf("Sign() error = %v", err) + } + + // Check that all fields are set + if event.PubKey == "" { + t.Error("Sign() did not set PubKey") + } + if event.ID == "" { + t.Error("Sign() did not set ID") + } + if event.Sig == "" { + t.Error("Sign() did not set Sig") + } + + // PubKey should match + if event.PubKey != key.Public() { + t.Errorf("PubKey = %s, want %s", event.PubKey, key.Public()) + } + + // Signature should be 128 hex characters (64 bytes) + if len(event.Sig) != 128 { + t.Errorf("Signature length = %d, want 128", len(event.Sig)) + } +} + +func TestKeySignPublicOnlyError(t *testing.T) { + fullKey, _ := GenerateKey() + pubOnlyKey, _ := ParsePublicKey(fullKey.Public()) + + event := &Event{ + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{}, + Content: "Test", + } + + err := pubOnlyKey.Sign(event) + if err == nil { + t.Error("Sign() with public-only key should return error") + } +} + +func TestEventVerify(t *testing.T) { + key, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + event := &Event{ + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{{"test", "value"}}, + Content: "Test message for verification", + } + + if err := key.Sign(event); err != nil { + t.Fatalf("Sign() error = %v", err) + } + + if !event.Verify() { + t.Error("Verify() returned false for valid signature") + } +} + +func TestEventVerifyInvalid(t *testing.T) { + key, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + event := &Event{ + CreatedAt: 1704067200, + Kind: 1, + Tags: Tags{}, + Content: "Test message", + } + + if err := key.Sign(event); err != nil { + t.Fatalf("Sign() error = %v", err) + } + + // Corrupt the content (ID becomes invalid) + event.Content = "Modified content" + if event.Verify() { + t.Error("Verify() returned true for modified content") + } + + // Restore content but corrupt signature + event.Content = "Test message" + event.SetID() + event.Sig = "0000000000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000000" + if event.Verify() { + t.Error("Verify() returned true for invalid signature") + } +} + +func TestSignAndVerifyRoundTrip(t *testing.T) { + // Generate key + key, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + // Create and sign event + event := &Event{ + CreatedAt: 1704067200, + Kind: KindTextNote, + Tags: Tags{{"t", "test"}}, + Content: "Integration test message", + } + + if err := key.Sign(event); err != nil { + t.Fatalf("Sign() error = %v", err) + } + + // Verify public key matches + if event.PubKey != key.Public() { + t.Errorf("Signed event PubKey = %s, want %s", event.PubKey, key.Public()) + } + + // Verify the signature + if !event.Verify() { + t.Error("Verify() failed for freshly signed event") + } + + // Check ID is correct + if !event.CheckID() { + t.Error("CheckID() failed for freshly signed event") + } +} diff --git a/internal/nostr/kinds.go b/internal/nostr/kinds.go new file mode 100644 index 0000000..cb76e88 --- /dev/null +++ b/internal/nostr/kinds.go @@ -0,0 +1,51 @@ +package nostr + +// Event kind constants as defined in NIP-01 and related NIPs. +const ( + KindMetadata = 0 + KindTextNote = 1 + KindContactList = 3 + KindEncryptedDM = 4 + KindDeletion = 5 + KindRepost = 6 + KindReaction = 7 +) + +// IsRegular returns true if the kind is a regular event (stored, not replaced). +// Regular events: 1000 <= kind < 10000 or kind in {0,1,2,...} except replaceable ones. +func IsRegular(kind int) bool { + if kind == KindMetadata || kind == KindContactList { + return false + } + if kind >= 10000 && kind < 20000 { + return false // replaceable + } + if kind >= 20000 && kind < 30000 { + return false // ephemeral + } + if kind >= 30000 && kind < 40000 { + return false // addressable + } + return true +} + +// IsReplaceable returns true if the kind is replaceable (NIP-01). +// Replaceable events: 10000 <= kind < 20000, or kind 0 (metadata) or kind 3 (contact list). +func IsReplaceable(kind int) bool { + if kind == KindMetadata || kind == KindContactList { + return true + } + return kind >= 10000 && kind < 20000 +} + +// IsEphemeral returns true if the kind is ephemeral (not stored). +// Ephemeral events: 20000 <= kind < 30000. +func IsEphemeral(kind int) bool { + return kind >= 20000 && kind < 30000 +} + +// IsAddressable returns true if the kind is addressable (parameterized replaceable). +// Addressable events: 30000 <= kind < 40000. +func IsAddressable(kind int) bool { + return kind >= 30000 && kind < 40000 +} diff --git a/internal/nostr/kinds_test.go b/internal/nostr/kinds_test.go new file mode 100644 index 0000000..2bf013d --- /dev/null +++ b/internal/nostr/kinds_test.go @@ -0,0 +1,128 @@ +package nostr + +import ( + "testing" +) + +func TestKindConstants(t *testing.T) { + // Verify constants match NIP-01 spec + tests := []struct { + name string + kind int + value int + }{ + {"Metadata", KindMetadata, 0}, + {"TextNote", KindTextNote, 1}, + {"ContactList", KindContactList, 3}, + {"EncryptedDM", KindEncryptedDM, 4}, + {"Deletion", KindDeletion, 5}, + {"Repost", KindRepost, 6}, + {"Reaction", KindReaction, 7}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.kind != tt.value { + t.Errorf("Kind%s = %d, want %d", tt.name, tt.kind, tt.value) + } + }) + } +} + +func TestIsRegular(t *testing.T) { + tests := []struct { + kind int + want bool + }{ + {0, false}, // Metadata - replaceable + {1, true}, // TextNote - regular + {3, false}, // ContactList - replaceable + {4, true}, // EncryptedDM - regular + {5, true}, // Deletion - regular + {1000, true}, // Regular range + {9999, true}, // Regular range + {10000, false}, // Replaceable range + {19999, false}, // Replaceable range + {20000, false}, // Ephemeral range + {29999, false}, // Ephemeral range + {30000, false}, // Addressable range + {39999, false}, // Addressable range + {40000, true}, // Back to regular + } + + for _, tt := range tests { + t.Run("kind_"+string(rune(tt.kind)), func(t *testing.T) { + if got := IsRegular(tt.kind); got != tt.want { + t.Errorf("IsRegular(%d) = %v, want %v", tt.kind, got, tt.want) + } + }) + } +} + +func TestIsReplaceable(t *testing.T) { + tests := []struct { + kind int + want bool + }{ + {0, true}, // Metadata + {1, false}, // TextNote + {3, true}, // ContactList + {10000, true}, // Replaceable range start + {15000, true}, // Replaceable range middle + {19999, true}, // Replaceable range end + {20000, false}, // Ephemeral range + {30000, false}, // Addressable range + } + + for _, tt := range tests { + t.Run("kind_"+string(rune(tt.kind)), func(t *testing.T) { + if got := IsReplaceable(tt.kind); got != tt.want { + t.Errorf("IsReplaceable(%d) = %v, want %v", tt.kind, got, tt.want) + } + }) + } +} + +func TestIsEphemeral(t *testing.T) { + tests := []struct { + kind int + want bool + }{ + {1, false}, // TextNote + {19999, false}, // Replaceable range + {20000, true}, // Ephemeral range start + {25000, true}, // Ephemeral range middle + {29999, true}, // Ephemeral range end + {30000, false}, // Addressable range + } + + for _, tt := range tests { + t.Run("kind_"+string(rune(tt.kind)), func(t *testing.T) { + if got := IsEphemeral(tt.kind); got != tt.want { + t.Errorf("IsEphemeral(%d) = %v, want %v", tt.kind, got, tt.want) + } + }) + } +} + +func TestIsAddressable(t *testing.T) { + tests := []struct { + kind int + want bool + }{ + {1, false}, // TextNote + {29999, false}, // Ephemeral range + {30000, true}, // Addressable range start + {35000, true}, // Addressable range middle + {39999, true}, // Addressable range end + {40000, false}, // Beyond addressable range + } + + for _, tt := range tests { + t.Run("kind_"+string(rune(tt.kind)), func(t *testing.T) { + if got := IsAddressable(tt.kind); got != tt.want { + t.Errorf("IsAddressable(%d) = %v, want %v", tt.kind, got, tt.want) + } + }) + } +} diff --git a/internal/nostr/relay.go b/internal/nostr/relay.go new file mode 100644 index 0000000..2b156e0 --- /dev/null +++ b/internal/nostr/relay.go @@ -0,0 +1,305 @@ +package nostr + +import ( + "context" + "crypto/rand" + "fmt" + "sync" + + "northwest.io/nostr-grpc/internal/websocket" +) + +// Relay represents a connection to a Nostr relay. +type Relay struct { + URL string + conn *websocket.Conn + mu sync.Mutex + + subscriptions map[string]*Subscription + subscriptionsMu sync.RWMutex + + okChannels map[string]chan *OKEnvelope + okChannelsMu sync.Mutex +} + +// Connect establishes a WebSocket connection to the relay. +func Connect(ctx context.Context, url string) (*Relay, error) { + conn, err := websocket.Dial(ctx, url) + if err != nil { + return nil, fmt.Errorf("failed to connect to relay: %w", err) + } + + r := &Relay{ + URL: url, + conn: conn, + subscriptions: make(map[string]*Subscription), + okChannels: make(map[string]chan *OKEnvelope), + } + + go r.Listen(ctx) + + return r, nil +} + +// Close closes the WebSocket connection. +func (r *Relay) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.conn == nil { + return nil + } + + err := r.conn.Close(websocket.StatusNormalClosure, "") + r.conn = nil + return err +} + +// Send sends an envelope to the relay. +func (r *Relay) Send(ctx context.Context, env Envelope) error { + data, err := env.MarshalJSON() + if err != nil { + return fmt.Errorf("failed to marshal envelope: %w", err) + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.conn == nil { + return fmt.Errorf("connection closed") + } + + return r.conn.Write(ctx, websocket.MessageText, data) +} + +// Receive reads the next envelope from the relay. +func (r *Relay) Receive(ctx context.Context) (Envelope, error) { + r.mu.Lock() + conn := r.conn + r.mu.Unlock() + + if conn == nil { + return nil, fmt.Errorf("connection closed") + } + + _, data, err := conn.Read(ctx) + if err != nil { + return nil, fmt.Errorf("failed to read message: %w", err) + } + + return ParseEnvelope(data) +} + +// Publish sends an event to the relay and waits for the OK response. +func (r *Relay) Publish(ctx context.Context, event *Event) error { + ch := make(chan *OKEnvelope, 1) + + r.okChannelsMu.Lock() + r.okChannels[event.ID] = ch + r.okChannelsMu.Unlock() + + defer func() { + r.okChannelsMu.Lock() + delete(r.okChannels, event.ID) + r.okChannelsMu.Unlock() + }() + + env := EventEnvelope{Event: event} + if err := r.Send(ctx, env); err != nil { + return fmt.Errorf("failed to send event: %w", err) + } + + select { + case ok := <-ch: + if !ok.OK { + return fmt.Errorf("event rejected: %s", ok.Message) + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func genID() string { + buf := make([]byte, 5) + rand.Read(buf) + return fmt.Sprintf("%x", buf) +} + +// subscribe is the internal implementation for Subscribe and Fetch. +func (r *Relay) subscribe(ctx context.Context, closeOnEOSE bool, filters ...Filter) *Subscription { + id := genID() + + sub := &Subscription{ + ID: id, + relay: r, + Filters: filters, + Events: make(chan *Event, 100), + closeOnEOSE: closeOnEOSE, + } + + r.subscriptionsMu.Lock() + r.subscriptions[id] = sub + r.subscriptionsMu.Unlock() + + go func() { + <-ctx.Done() + sub.stop(ctx.Err()) + r.subscriptionsMu.Lock() + delete(r.subscriptions, id) + r.subscriptionsMu.Unlock() + }() + + env := ReqEnvelope{ + SubscriptionID: id, + Filters: filters, + } + if err := r.Send(ctx, env); err != nil { + r.subscriptionsMu.Lock() + delete(r.subscriptions, id) + r.subscriptionsMu.Unlock() + sub.stop(fmt.Errorf("failed to send subscription request: %w", err)) + } + + return sub +} + +// Subscribe creates a subscription with the given filters. +// Events are received on the Events channel until the context is cancelled. +// After EOSE (end of stored events), the subscription continues to receive +// real-time events per the Nostr protocol. +func (r *Relay) Subscribe(ctx context.Context, filters ...Filter) *Subscription { + return r.subscribe(ctx, false, filters...) +} + +// Fetch creates a subscription that closes automatically when EOSE is received. +// Use this for one-shot queries where you only want stored events. +func (r *Relay) Fetch(ctx context.Context, filters ...Filter) *Subscription { + return r.subscribe(ctx, true, filters...) +} + +// dispatchEnvelope routes incoming messages to the appropriate subscription. +func (r *Relay) dispatchEnvelope(env Envelope) { + switch e := env.(type) { + case *EventEnvelope: + r.subscriptionsMu.RLock() + sub, ok := r.subscriptions[e.SubscriptionID] + r.subscriptionsMu.RUnlock() + if ok { + sub.send(e.Event) + } + case *EOSEEnvelope: + r.subscriptionsMu.RLock() + sub, ok := r.subscriptions[e.SubscriptionID] + r.subscriptionsMu.RUnlock() + if ok && sub.closeOnEOSE { + r.subscriptionsMu.Lock() + delete(r.subscriptions, e.SubscriptionID) + r.subscriptionsMu.Unlock() + sub.stop(nil) + } + case *ClosedEnvelope: + r.subscriptionsMu.Lock() + sub, ok := r.subscriptions[e.SubscriptionID] + if ok { + delete(r.subscriptions, e.SubscriptionID) + } + r.subscriptionsMu.Unlock() + if ok { + sub.stop(fmt.Errorf("subscription closed by relay: %s", e.Message)) + } + case *OKEnvelope: + r.okChannelsMu.Lock() + ch, ok := r.okChannels[e.EventID] + r.okChannelsMu.Unlock() + if ok { + select { + case ch <- e: + default: + } + } + } +} + +// Listen reads messages from the relay and dispatches them to subscriptions. +func (r *Relay) Listen(ctx context.Context) error { + defer func() { + r.subscriptionsMu.Lock() + subs := make([]*Subscription, 0, len(r.subscriptions)) + for id, sub := range r.subscriptions { + subs = append(subs, sub) + delete(r.subscriptions, id) + } + r.subscriptionsMu.Unlock() + + for _, sub := range subs { + sub.stop(fmt.Errorf("connection closed")) + } + }() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + env, err := r.Receive(ctx) + if err != nil { + return err + } + + r.dispatchEnvelope(env) + } +} + +// Subscription represents an active subscription to a relay. +type Subscription struct { + ID string + relay *Relay + Filters []Filter + Events chan *Event + Err error + + closeOnEOSE bool + mu sync.Mutex + done bool +} + +// send delivers an event to the subscription's Events channel. +func (s *Subscription) send(ev *Event) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + select { + case s.Events <- ev: + default: + } +} + +// stop closes the subscription's Events channel and sets the error. +// It is idempotent — only the first call has any effect. +func (s *Subscription) stop(err error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + s.done = true + s.Err = err + close(s.Events) +} + +// Close unsubscribes from the relay. +func (s *Subscription) Close(ctx context.Context) error { + s.stop(nil) + + s.relay.subscriptionsMu.Lock() + delete(s.relay.subscriptions, s.ID) + s.relay.subscriptionsMu.Unlock() + + env := CloseEnvelope{SubscriptionID: s.ID} + return s.relay.Send(ctx, env) +} diff --git a/internal/nostr/relay_test.go b/internal/nostr/relay_test.go new file mode 100644 index 0000000..02bd8e5 --- /dev/null +++ b/internal/nostr/relay_test.go @@ -0,0 +1,326 @@ +package nostr + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "northwest.io/nostr-grpc/internal/websocket" +) + +// mockRelay creates a test WebSocket server that echoes messages +func mockRelay(t *testing.T, handler func(conn *websocket.Conn)) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r) + if err != nil { + t.Logf("Failed to accept WebSocket: %v", err) + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + handler(conn) + })) +} + +func TestConnect(t *testing.T) { + server := mockRelay(t, func(conn *websocket.Conn) { + // Just accept and wait + time.Sleep(100 * time.Millisecond) + }) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + ctx := context.Background() + + relay, err := Connect(ctx, url) + if err != nil { + t.Fatalf("Connect() error = %v", err) + } + defer relay.Close() + + if relay.URL != url { + t.Errorf("Relay.URL = %s, want %s", relay.URL, url) + } +} + +func TestConnectError(t *testing.T) { + ctx := context.Background() + _, err := Connect(ctx, "ws://localhost:99999") + if err == nil { + t.Error("Connect() expected error for invalid URL") + } +} + +func TestRelaySendReceive(t *testing.T) { + server := mockRelay(t, func(conn *websocket.Conn) { + // Read message + _, data, err := conn.Read(context.Background()) + if err != nil { + t.Logf("Read error: %v", err) + return + } + + // Echo it back as NOTICE + var arr []interface{} + json.Unmarshal(data, &arr) + + response, _ := json.Marshal([]interface{}{"NOTICE", "received: " + arr[0].(string)}) + conn.Write(context.Background(), websocket.MessageText, response) + }) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + ctx := context.Background() + + // Create relay without auto-Listen to test Send/Receive directly + conn, err := websocket.Dial(ctx, url) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + relay := &Relay{ + URL: url, + conn: conn, + subscriptions: make(map[string]*Subscription), + okChannels: make(map[string]chan *OKEnvelope), + } + defer relay.Close() + + // Send a CLOSE envelope + closeEnv := CloseEnvelope{SubscriptionID: "test"} + if err := relay.Send(ctx, closeEnv); err != nil { + t.Fatalf("Send() error = %v", err) + } + + // Receive response + env, err := relay.Receive(ctx) + if err != nil { + t.Fatalf("Receive() error = %v", err) + } + + noticeEnv, ok := env.(*NoticeEnvelope) + if !ok { + t.Fatalf("Expected *NoticeEnvelope, got %T", env) + } + + if !strings.Contains(noticeEnv.Message, "CLOSE") { + t.Errorf("Message = %s, want to contain 'CLOSE'", noticeEnv.Message) + } +} + +func TestRelayPublish(t *testing.T) { + server := mockRelay(t, func(conn *websocket.Conn) { + // Read the EVENT message + _, data, err := conn.Read(context.Background()) + if err != nil { + t.Logf("Read error: %v", err) + return + } + + // Parse to get event ID + var arr []json.RawMessage + json.Unmarshal(data, &arr) + + var event Event + json.Unmarshal(arr[1], &event) + + // Send OK response + response, _ := json.Marshal([]interface{}{"OK", event.ID, true, ""}) + conn.Write(context.Background(), websocket.MessageText, response) + }) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + ctx := context.Background() + + relay, err := Connect(ctx, url) + if err != nil { + t.Fatalf("Connect() error = %v", err) + } + defer relay.Close() + + // Create and sign event + key, _ := GenerateKey() + event := &Event{ + CreatedAt: time.Now().Unix(), + Kind: KindTextNote, + Tags: Tags{}, + Content: "Test publish", + } + key.Sign(event) + + // Publish + if err := relay.Publish(ctx, event); err != nil { + t.Fatalf("Publish() error = %v", err) + } +} + +func TestRelayPublishRejected(t *testing.T) { + server := mockRelay(t, func(conn *websocket.Conn) { + // Read the EVENT message + _, data, err := conn.Read(context.Background()) + if err != nil { + return + } + + var arr []json.RawMessage + json.Unmarshal(data, &arr) + + var event Event + json.Unmarshal(arr[1], &event) + + // Send rejection + response, _ := json.Marshal([]interface{}{"OK", event.ID, false, "blocked: spam"}) + conn.Write(context.Background(), websocket.MessageText, response) + }) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + ctx := context.Background() + + relay, err := Connect(ctx, url) + if err != nil { + t.Fatalf("Connect() error = %v", err) + } + defer relay.Close() + + key, _ := GenerateKey() + event := &Event{ + CreatedAt: time.Now().Unix(), + Kind: KindTextNote, + Tags: Tags{}, + Content: "Test", + } + key.Sign(event) + + err = relay.Publish(ctx, event) + if err == nil { + t.Error("Publish() expected error for rejected event") + } + if !strings.Contains(err.Error(), "rejected") { + t.Errorf("Error = %v, want to contain 'rejected'", err) + } +} + +func TestRelaySubscribe(t *testing.T) { + server := mockRelay(t, func(conn *websocket.Conn) { + // Read REQ + _, data, err := conn.Read(context.Background()) + if err != nil { + return + } + + var arr []json.RawMessage + json.Unmarshal(data, &arr) + + var subID string + json.Unmarshal(arr[1], &subID) + + // Send some events + for i := 0; i < 3; i++ { + event := Event{ + ID: "event" + string(rune('0'+i)), + PubKey: "pubkey", + CreatedAt: time.Now().Unix(), + Kind: 1, + Tags: Tags{}, + Content: "Test event", + Sig: "sig", + } + response, _ := json.Marshal([]interface{}{"EVENT", subID, event}) + conn.Write(context.Background(), websocket.MessageText, response) + } + + // Send EOSE + eose, _ := json.Marshal([]interface{}{"EOSE", subID}) + conn.Write(context.Background(), websocket.MessageText, eose) + }) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + relay, err := Connect(ctx, url) + if err != nil { + t.Fatalf("Connect() error = %v", err) + } + defer relay.Close() + + sub := relay.Fetch(ctx, Filter{Kinds: []int{1}}) + + eventCount := 0 + for range sub.Events { + eventCount++ + } + + if eventCount != 3 { + t.Errorf("Received %d events, want 3", eventCount) + } + if sub.Err != nil { + t.Errorf("Subscription.Err = %v, want nil", sub.Err) + } +} + +func TestRelayClose(t *testing.T) { + server := mockRelay(t, func(conn *websocket.Conn) { + time.Sleep(100 * time.Millisecond) + }) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + ctx := context.Background() + + relay, err := Connect(ctx, url) + if err != nil { + t.Fatalf("Connect() error = %v", err) + } + + if err := relay.Close(); err != nil { + t.Errorf("Close() error = %v", err) + } + + // Second close should be safe + if err := relay.Close(); err != nil { + t.Errorf("Second Close() error = %v", err) + } +} + +func TestSubscriptionClose(t *testing.T) { + server := mockRelay(t, func(conn *websocket.Conn) { + // Read REQ + conn.Read(context.Background()) + + // Wait for CLOSE + _, data, err := conn.Read(context.Background()) + if err != nil { + return + } + + var arr []interface{} + json.Unmarshal(data, &arr) + + if arr[0] != "CLOSE" { + t.Errorf("Expected CLOSE, got %v", arr[0]) + } + }) + defer server.Close() + + url := "ws" + strings.TrimPrefix(server.URL, "http") + ctx := context.Background() + + relay, err := Connect(ctx, url) + if err != nil { + t.Fatalf("Connect() error = %v", err) + } + defer relay.Close() + + sub := relay.Subscribe(ctx, Filter{Kinds: []int{1}}) + + if err := sub.Close(ctx); err != nil { + t.Errorf("Subscription.Close() error = %v", err) + } +} diff --git a/internal/nostr/tags.go b/internal/nostr/tags.go new file mode 100644 index 0000000..4fe3d04 --- /dev/null +++ b/internal/nostr/tags.go @@ -0,0 +1,64 @@ +package nostr + +// Tag represents a single Nostr tag, which is an array of strings. +// The first element is the tag key, followed by its values. +type Tag []string + +// Key returns the tag key (first element), or empty string if tag is empty. +func (t Tag) Key() string { + if len(t) == 0 { + return "" + } + return t[0] +} + +// Value returns the first value (second element), or empty string if not present. +func (t Tag) Value() string { + if len(t) < 2 { + return "" + } + return t[1] +} + +// Tags represents a collection of tags. +type Tags []Tag + +// Find returns the first tag matching the given key, or nil if not found. +func (tags Tags) Find(key string) Tag { + for _, tag := range tags { + if tag.Key() == key { + return tag + } + } + return nil +} + +// FindAll returns all tags matching the given key. +func (tags Tags) FindAll(key string) Tags { + var result Tags + for _, tag := range tags { + if tag.Key() == key { + result = append(result, tag) + } + } + return result +} + +// GetD returns the value of the "d" tag, used for addressable events. +func (tags Tags) GetD() string { + tag := tags.Find("d") + if tag == nil { + return "" + } + return tag.Value() +} + +// ContainsValue checks if any tag with the given key contains the specified value. +func (tags Tags) ContainsValue(key, value string) bool { + for _, tag := range tags { + if tag.Key() == key && tag.Value() == value { + return true + } + } + return false +} diff --git a/internal/nostr/tags_test.go b/internal/nostr/tags_test.go new file mode 100644 index 0000000..7796606 --- /dev/null +++ b/internal/nostr/tags_test.go @@ -0,0 +1,158 @@ +package nostr + +import ( + "testing" +) + +func TestTagKey(t *testing.T) { + tests := []struct { + name string + tag Tag + want string + }{ + {"empty tag", Tag{}, ""}, + {"single element", Tag{"e"}, "e"}, + {"multiple elements", Tag{"p", "abc123", "relay"}, "p"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.tag.Key(); got != tt.want { + t.Errorf("Tag.Key() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestTagValue(t *testing.T) { + tests := []struct { + name string + tag Tag + want string + }{ + {"empty tag", Tag{}, ""}, + {"single element", Tag{"e"}, ""}, + {"two elements", Tag{"p", "abc123"}, "abc123"}, + {"multiple elements", Tag{"e", "eventid", "relay", "marker"}, "eventid"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.tag.Value(); got != tt.want { + t.Errorf("Tag.Value() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestTagsFind(t *testing.T) { + tags := Tags{ + {"e", "event1"}, + {"p", "pubkey1"}, + {"e", "event2"}, + {"d", "identifier"}, + } + + tests := []struct { + name string + key string + wantNil bool + wantVal string + }{ + {"find first e", "e", false, "event1"}, + {"find p", "p", false, "pubkey1"}, + {"find d", "d", false, "identifier"}, + {"find nonexistent", "x", true, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tags.Find(tt.key) + if tt.wantNil { + if got != nil { + t.Errorf("Tags.Find(%q) = %v, want nil", tt.key, got) + } + } else { + if got == nil { + t.Errorf("Tags.Find(%q) = nil, want value %q", tt.key, tt.wantVal) + } else if got.Value() != tt.wantVal { + t.Errorf("Tags.Find(%q).Value() = %q, want %q", tt.key, got.Value(), tt.wantVal) + } + } + }) + } +} + +func TestTagsFindAll(t *testing.T) { + tags := Tags{ + {"e", "event1"}, + {"p", "pubkey1"}, + {"e", "event2"}, + {"e", "event3"}, + } + + found := tags.FindAll("e") + if len(found) != 3 { + t.Errorf("Tags.FindAll(\"e\") returned %d tags, want 3", len(found)) + } + + found = tags.FindAll("p") + if len(found) != 1 { + t.Errorf("Tags.FindAll(\"p\") returned %d tags, want 1", len(found)) + } + + found = tags.FindAll("x") + if len(found) != 0 { + t.Errorf("Tags.FindAll(\"x\") returned %d tags, want 0", len(found)) + } +} + +func TestTagsGetD(t *testing.T) { + tests := []struct { + name string + tags Tags + want string + }{ + {"no d tag", Tags{{"e", "event1"}}, ""}, + {"empty d tag", Tags{{"d"}}, ""}, + {"d tag present", Tags{{"d", "my-identifier"}}, "my-identifier"}, + {"d tag with extras", Tags{{"d", "id", "extra"}}, "id"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.tags.GetD(); got != tt.want { + t.Errorf("Tags.GetD() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestTagsContainsValue(t *testing.T) { + tags := Tags{ + {"e", "event1"}, + {"p", "pubkey1"}, + {"e", "event2"}, + } + + tests := []struct { + key string + value string + want bool + }{ + {"e", "event1", true}, + {"e", "event2", true}, + {"e", "event3", false}, + {"p", "pubkey1", true}, + {"p", "pubkey2", false}, + {"x", "anything", false}, + } + + for _, tt := range tests { + t.Run(tt.key+"="+tt.value, func(t *testing.T) { + if got := tags.ContainsValue(tt.key, tt.value); got != tt.want { + t.Errorf("Tags.ContainsValue(%q, %q) = %v, want %v", tt.key, tt.value, got, tt.want) + } + }) + } +} diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go new file mode 100644 index 0000000..fe937c8 --- /dev/null +++ b/internal/websocket/websocket.go @@ -0,0 +1,297 @@ +package websocket + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/sha1" + "crypto/tls" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +type MessageType int + +const MessageText MessageType = 1 + +type StatusCode int + +const StatusNormalClosure StatusCode = 1000 + +const ( + opText = 0x1 + opClose = 0x8 + opPing = 0x9 + opPong = 0xA +) + +type Conn struct { + rwc net.Conn + br *bufio.Reader + client bool + mu sync.Mutex +} + +func mask(key [4]byte, data []byte) { + for i := range data { + data[i] ^= key[i%4] + } +} + +func (c *Conn) writeFrame(opcode byte, payload []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + length := len(payload) + header := []byte{0x80 | opcode, 0} // FIN + opcode + + if c.client { + header[1] = 0x80 // mask bit + } + + switch { + case length <= 125: + header[1] |= byte(length) + case length <= 65535: + header[1] |= 126 + ext := make([]byte, 2) + binary.BigEndian.PutUint16(ext, uint16(length)) + header = append(header, ext...) + default: + header[1] |= 127 + ext := make([]byte, 8) + binary.BigEndian.PutUint64(ext, uint64(length)) + header = append(header, ext...) + } + + if c.client { + var key [4]byte + rand.Read(key[:]) + header = append(header, key[:]...) + mask(key, payload) + } + + if _, err := c.rwc.Write(header); err != nil { + return err + } + _, err := c.rwc.Write(payload) + return err +} + +func (c *Conn) readFrame() (fin bool, opcode byte, payload []byte, err error) { + var hdr [2]byte + if _, err = io.ReadFull(c.br, hdr[:]); err != nil { + return + } + + fin = hdr[0]&0x80 != 0 + opcode = hdr[0] & 0x0F + masked := hdr[1]&0x80 != 0 + length := uint64(hdr[1] & 0x7F) + + switch length { + case 126: + var ext [2]byte + if _, err = io.ReadFull(c.br, ext[:]); err != nil { + return + } + length = uint64(binary.BigEndian.Uint16(ext[:])) + case 127: + var ext [8]byte + if _, err = io.ReadFull(c.br, ext[:]); err != nil { + return + } + length = binary.BigEndian.Uint64(ext[:]) + } + + var key [4]byte + if masked { + if _, err = io.ReadFull(c.br, key[:]); err != nil { + return + } + } + + payload = make([]byte, length) + if _, err = io.ReadFull(c.br, payload); err != nil { + return + } + + if masked { + mask(key, payload) + } + return +} + +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + stop := context.AfterFunc(ctx, func() { + c.rwc.SetReadDeadline(time.Now()) + }) + defer stop() + + var buf []byte + for { + fin, opcode, payload, err := c.readFrame() + if err != nil { + if ctx.Err() != nil { + return 0, nil, ctx.Err() + } + return 0, nil, err + } + + switch opcode { + case opPing: + c.writeFrame(opPong, payload) + continue + case opClose: + return 0, nil, fmt.Errorf("websocket: close frame received") + case opText, 0x0: // text or continuation + buf = append(buf, payload...) + if fin { + return MessageText, buf, nil + } + default: + buf = append(buf, payload...) + if fin { + return MessageText, buf, nil + } + } + } +} + +func (c *Conn) Write(ctx context.Context, typ MessageType, data []byte) error { + return c.writeFrame(byte(typ), data) +} + +func (c *Conn) Close(code StatusCode, reason string) error { + payload := make([]byte, 2+len(reason)) + binary.BigEndian.PutUint16(payload, uint16(code)) + copy(payload[2:], reason) + c.writeFrame(opClose, payload) + return c.rwc.Close() +} + +var wsGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + +func acceptKey(key string) string { + h := sha1.New() + h.Write([]byte(key)) + h.Write([]byte(wsGUID)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func Dial(ctx context.Context, rawURL string) (*Conn, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + + host := u.Hostname() + port := u.Port() + useTLS := u.Scheme == "wss" + + if port == "" { + if useTLS { + port = "443" + } else { + port = "80" + } + } + + addr := net.JoinHostPort(host, port) + rwc, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + if useTLS { + tc := tls.Client(rwc, &tls.Config{ServerName: host}) + if err := tc.HandshakeContext(ctx); err != nil { + rwc.Close() + return nil, err + } + rwc = tc + } + + br := bufio.NewReader(rwc) + + var keyBytes [16]byte + rand.Read(keyBytes[:]) + key := base64.StdEncoding.EncodeToString(keyBytes[:]) + + path := u.RequestURI() + reqStr := "GET " + path + " HTTP/1.1\r\n" + + "Host: " + host + "\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Key: " + key + "\r\n" + + "Sec-WebSocket-Version: 13\r\n\r\n" + + if _, err := rwc.Write([]byte(reqStr)); err != nil { + rwc.Close() + return nil, err + } + + req := &http.Request{Method: "GET"} + resp, err := http.ReadResponse(br, req) + if err != nil { + rwc.Close() + return nil, err + } + resp.Body.Close() + + if resp.StatusCode != 101 { + rwc.Close() + return nil, fmt.Errorf("websocket: bad handshake status %d", resp.StatusCode) + } + + got := resp.Header.Get("Sec-WebSocket-Accept") + want := acceptKey(key) + if got != want { + rwc.Close() + return nil, fmt.Errorf("websocket: invalid Sec-WebSocket-Accept") + } + + return &Conn{rwc: rwc, br: br, client: true}, nil +} + +func Accept(w http.ResponseWriter, r *http.Request) (*Conn, error) { + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + return nil, fmt.Errorf("websocket: missing Upgrade header") + } + + key := r.Header.Get("Sec-WebSocket-Key") + if key == "" { + return nil, fmt.Errorf("websocket: missing Sec-WebSocket-Key") + } + + hj, ok := w.(http.Hijacker) + if !ok { + return nil, fmt.Errorf("websocket: response does not support hijacking") + } + + rwc, brw, err := hj.Hijack() + if err != nil { + return nil, err + } + + accept := acceptKey(key) + respStr := "HTTP/1.1 101 Switching Protocols\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Accept: " + accept + "\r\n\r\n" + + if _, err := rwc.Write([]byte(respStr)); err != nil { + rwc.Close() + return nil, err + } + + return &Conn{rwc: rwc, br: brw.Reader, client: false}, nil +} -- cgit v1.2.3