From e79f9ad89556000521b43ce5ff4eb59dd00768b0 Mon Sep 17 00:00:00 2001 From: bndw Date: Sat, 7 Feb 2026 21:22:51 -0800 Subject: refactor: race-safe Subscribe/Fetch API with channel-based Publish - Add mutex-guarded send/stop on Subscription to prevent send-on-closed-channel panics and data races - Split Subscribe (streams after EOSE) and Fetch (closes on EOSE) per NIP-01 - Rewrite Publish to use channel-based OK dispatch instead of calling Receive directly, which raced with the auto-started Listen goroutine - Clean up all subscriptions when Listen exits so range loops terminate - Update tests and examples for new API --- example_test.go | 35 +++------- examples/basic/main.go | 39 ++--------- relay.go | 186 ++++++++++++++++++++++++++++++++++++------------- relay_test.go | 51 ++++++-------- 4 files changed, 176 insertions(+), 135 deletions(-) diff --git a/example_test.go b/example_test.go index 90dae0f..6d10ced 100644 --- a/example_test.go +++ b/example_test.go @@ -53,8 +53,7 @@ func Example_basic() { // ExampleRelay demonstrates connecting to a relay (requires network). // This is a documentation example - run with: go test -v -run ExampleRelay func ExampleRelay() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + ctx := context.Background() // Connect to a public relay relay, err := nostr.Connect(ctx, "wss://relay.damus.io") @@ -66,35 +65,21 @@ func ExampleRelay() { fmt.Println("Connected to relay!") - // Subscribe to recent text notes + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + // Fetch recent text notes (closes on EOSE) since := time.Now().Add(-1 * time.Hour).Unix() - sub, err := relay.Subscribe(ctx, "my-sub", nostr.Filter{ + sub := relay.Fetch(ctx, nostr.Filter{ Kinds: []int{nostr.KindTextNote}, Since: &since, Limit: 5, }) - if err != nil { - fmt.Printf("Failed to subscribe: %v\n", err) - return - } - - // Listen for events in the background - go relay.Listen(ctx) - // Collect events until EOSE eventCount := 0 - for { - select { - case event := <-sub.Events: - eventCount++ - fmt.Printf("Received event from %s...\n", event.PubKey[:8]) - case <-sub.EOSE: - fmt.Printf("Received %d events before EOSE\n", eventCount) - sub.Close(ctx) - return - case <-ctx.Done(): - fmt.Println("Timeout") - return - } + for event := range sub.Events { + eventCount++ + fmt.Printf("Received event from %s...\n", event.PubKey[:8]) } + fmt.Printf("Received %d events\n", eventCount) } diff --git a/examples/basic/main.go b/examples/basic/main.go index 0c99dd9..1a4061a 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -53,11 +53,8 @@ func main() { ExampleRelay() } -// ExampleRelay demonstrates connecting to a relay (requires network). -// This is a documentation example - run with: go test -v -run ExampleRelay func ExampleRelay() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + ctx := context.Background() // Connect to a public relay relay, err := nostr.Connect(ctx, "wss://relay.damus.io") @@ -66,38 +63,16 @@ func ExampleRelay() { return } defer relay.Close() - fmt.Println("Connected to relay!") - // Subscribe to recent text notes - since := time.Now().Add(-1 * time.Hour).Unix() - sub, err := relay.Subscribe(ctx, "my-sub", nostr.Filter{ + ctx, cancel := context.WithTimeout(ctx, 25*time.Second) + defer cancel() + + filter := nostr.Filter{ Kinds: []int{nostr.KindTextNote}, - Since: &since, Limit: 5, - }) - if err != nil { - fmt.Printf("Failed to subscribe: %v\n", err) - os.Exit(1) } - - // Listen for events in the background - go relay.Listen(ctx) - - // Collect events until EOSE - eventCount := 0 - for { - select { - case event := <-sub.Events: - eventCount++ - fmt.Printf("Received event from %s...\n", event) - case <-sub.EOSE: - fmt.Printf("Received %d events before EOSE\n", eventCount) - sub.Close(ctx) - return - case <-ctx.Done(): - fmt.Println("Timeout") - return - } + for event := range relay.Fetch(ctx, filter).Events { + fmt.Printf("Received event from %s...\n", event) } } diff --git a/relay.go b/relay.go index 45f6119..bda76af 100644 --- a/relay.go +++ b/relay.go @@ -2,6 +2,7 @@ package nostr import ( "context" + "crypto/rand" "fmt" "sync" @@ -16,6 +17,9 @@ type Relay struct { subscriptions map[string]*Subscription subscriptionsMu sync.RWMutex + + okChannels map[string]chan *OKEnvelope + okChannelsMu sync.Mutex } // Connect establishes a WebSocket connection to the relay. @@ -25,11 +29,16 @@ func Connect(ctx context.Context, url string) (*Relay, error) { return nil, fmt.Errorf("failed to connect to relay: %w", err) } - return &Relay{ + r := &Relay{ URL: url, conn: conn, subscriptions: make(map[string]*Subscription), - }, nil + okChannels: make(map[string]chan *OKEnvelope), + } + + go r.Listen(ctx) + + return r, nil } // Close closes the WebSocket connection. @@ -83,47 +92,64 @@ func (r *Relay) Receive(ctx context.Context) (Envelope, error) { // Publish sends an event to the relay and waits for the OK response. func (r *Relay) Publish(ctx context.Context, event *Event) error { + ch := make(chan *OKEnvelope, 1) + + r.okChannelsMu.Lock() + r.okChannels[event.ID] = ch + r.okChannelsMu.Unlock() + + defer func() { + r.okChannelsMu.Lock() + delete(r.okChannels, event.ID) + r.okChannelsMu.Unlock() + }() + env := EventEnvelope{Event: event} if err := r.Send(ctx, env); err != nil { return fmt.Errorf("failed to send event: %w", err) } - // Wait for OK response - for { - resp, err := r.Receive(ctx) - if err != nil { - return fmt.Errorf("failed to receive response: %w", err) - } - - if ok, isOK := resp.(*OKEnvelope); isOK { - if ok.EventID == event.ID { - if !ok.OK { - return fmt.Errorf("event rejected: %s", ok.Message) - } - return nil - } + select { + case ok := <-ch: + if !ok.OK { + return fmt.Errorf("event rejected: %s", ok.Message) } - - // Dispatch other messages to subscriptions - r.dispatchEnvelope(resp) + return nil + case <-ctx.Done(): + return ctx.Err() } } -// Subscribe creates a subscription with the given filters. -func (r *Relay) Subscribe(ctx context.Context, id string, filters ...Filter) (*Subscription, error) { +func genID() string { + buf := make([]byte, 5) + rand.Read(buf) + return fmt.Sprintf("%x", buf) +} + +// subscribe is the internal implementation for Subscribe and Fetch. +func (r *Relay) subscribe(ctx context.Context, closeOnEOSE bool, filters ...Filter) *Subscription { + id := genID() + sub := &Subscription{ - ID: id, - relay: r, - Filters: filters, - Events: make(chan *Event, 100), - EOSE: make(chan struct{}, 1), - closed: make(chan struct{}), + ID: id, + relay: r, + Filters: filters, + Events: make(chan *Event, 100), + closeOnEOSE: closeOnEOSE, } r.subscriptionsMu.Lock() r.subscriptions[id] = sub r.subscriptionsMu.Unlock() + go func() { + <-ctx.Done() + sub.stop(ctx.Err()) + r.subscriptionsMu.Lock() + delete(r.subscriptions, id) + r.subscriptionsMu.Unlock() + }() + env := ReqEnvelope{ SubscriptionID: id, Filters: filters, @@ -132,10 +158,24 @@ func (r *Relay) Subscribe(ctx context.Context, id string, filters ...Filter) (*S r.subscriptionsMu.Lock() delete(r.subscriptions, id) r.subscriptionsMu.Unlock() - return nil, fmt.Errorf("failed to send subscription request: %w", err) + sub.stop(fmt.Errorf("failed to send subscription request: %w", err)) } - return sub, nil + return sub +} + +// Subscribe creates a subscription with the given filters. +// Events are received on the Events channel until the context is cancelled. +// After EOSE (end of stored events), the subscription continues to receive +// real-time events per the Nostr protocol. +func (r *Relay) Subscribe(ctx context.Context, filters ...Filter) *Subscription { + return r.subscribe(ctx, false, filters...) +} + +// Fetch creates a subscription that closes automatically when EOSE is received. +// Use this for one-shot queries where you only want stored events. +func (r *Relay) Fetch(ctx context.Context, filters ...Filter) *Subscription { + return r.subscribe(ctx, true, filters...) } // dispatchEnvelope routes incoming messages to the appropriate subscription. @@ -146,35 +186,57 @@ func (r *Relay) dispatchEnvelope(env Envelope) { sub, ok := r.subscriptions[e.SubscriptionID] r.subscriptionsMu.RUnlock() if ok { - select { - case sub.Events <- e.Event: - default: - // Channel full, drop event - } + sub.send(e.Event) } case *EOSEEnvelope: r.subscriptionsMu.RLock() sub, ok := r.subscriptions[e.SubscriptionID] r.subscriptionsMu.RUnlock() - if ok { - select { - case sub.EOSE <- struct{}{}: - default: - } + if ok && sub.closeOnEOSE { + r.subscriptionsMu.Lock() + delete(r.subscriptions, e.SubscriptionID) + r.subscriptionsMu.Unlock() + sub.stop(nil) } case *ClosedEnvelope: r.subscriptionsMu.Lock() - if sub, ok := r.subscriptions[e.SubscriptionID]; ok { - close(sub.closed) + sub, ok := r.subscriptions[e.SubscriptionID] + if ok { delete(r.subscriptions, e.SubscriptionID) } r.subscriptionsMu.Unlock() + if ok { + sub.stop(fmt.Errorf("subscription closed by relay: %s", e.Message)) + } + case *OKEnvelope: + r.okChannelsMu.Lock() + ch, ok := r.okChannels[e.EventID] + r.okChannelsMu.Unlock() + if ok { + select { + case ch <- e: + default: + } + } } } // Listen reads messages from the relay and dispatches them to subscriptions. -// This should be called in a goroutine when using multiple subscriptions. func (r *Relay) Listen(ctx context.Context) error { + defer func() { + r.subscriptionsMu.Lock() + subs := make([]*Subscription, 0, len(r.subscriptions)) + for id, sub := range r.subscriptions { + subs = append(subs, sub) + delete(r.subscriptions, id) + } + r.subscriptionsMu.Unlock() + + for _, sub := range subs { + sub.stop(fmt.Errorf("connection closed")) + } + }() + for { select { case <-ctx.Done(): @@ -197,12 +259,43 @@ type Subscription struct { relay *Relay Filters []Filter Events chan *Event - EOSE chan struct{} - closed chan struct{} + Err error + + closeOnEOSE bool + mu sync.Mutex + done bool +} + +// send delivers an event to the subscription's Events channel. +func (s *Subscription) send(ev *Event) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + select { + case s.Events <- ev: + default: + } +} + +// stop closes the subscription's Events channel and sets the error. +// It is idempotent — only the first call has any effect. +func (s *Subscription) stop(err error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + s.done = true + s.Err = err + close(s.Events) } // Close unsubscribes from the relay. func (s *Subscription) Close(ctx context.Context) error { + s.stop(nil) + s.relay.subscriptionsMu.Lock() delete(s.relay.subscriptions, s.ID) s.relay.subscriptionsMu.Unlock() @@ -210,8 +303,3 @@ func (s *Subscription) Close(ctx context.Context) error { env := CloseEnvelope{SubscriptionID: s.ID} return s.relay.Send(ctx, env) } - -// Closed returns a channel that's closed when the subscription is terminated. -func (s *Subscription) Closed() <-chan struct{} { - return s.closed -} diff --git a/relay_test.go b/relay_test.go index 4ace956..b39aa06 100644 --- a/relay_test.go +++ b/relay_test.go @@ -76,9 +76,16 @@ func TestRelaySendReceive(t *testing.T) { url := "ws" + strings.TrimPrefix(server.URL, "http") ctx := context.Background() - relay, err := Connect(ctx, url) + // Create relay without auto-Listen to test Send/Receive directly + conn, _, err := websocket.Dial(ctx, url, nil) if err != nil { - t.Fatalf("Connect() error = %v", err) + t.Fatalf("Dial() error = %v", err) + } + relay := &Relay{ + URL: url, + conn: conn, + subscriptions: make(map[string]*Subscription), + okChannels: make(map[string]chan *OKEnvelope), } defer relay.Close() @@ -234,7 +241,8 @@ func TestRelaySubscribe(t *testing.T) { defer server.Close() url := "ws" + strings.TrimPrefix(server.URL, "http") - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() relay, err := Connect(ctx, url) if err != nil { @@ -242,30 +250,18 @@ func TestRelaySubscribe(t *testing.T) { } defer relay.Close() - sub, err := relay.Subscribe(ctx, "sub1", Filter{Kinds: []int{1}}) - if err != nil { - t.Fatalf("Subscribe() error = %v", err) - } + sub := relay.Fetch(ctx, Filter{Kinds: []int{1}}) - // Start listening in background - go relay.Listen(ctx) - - // Collect events eventCount := 0 - timeout := time.After(2 * time.Second) - - for { - select { - case <-sub.Events: - eventCount++ - case <-sub.EOSE: - if eventCount != 3 { - t.Errorf("Received %d events, want 3", eventCount) - } - return - case <-timeout: - t.Fatal("Timeout waiting for events") - } + for range sub.Events { + eventCount++ + } + + if eventCount != 3 { + t.Errorf("Received %d events, want 3", eventCount) + } + if sub.Err != nil { + t.Errorf("Subscription.Err = %v, want nil", sub.Err) } } @@ -322,10 +318,7 @@ func TestSubscriptionClose(t *testing.T) { } defer relay.Close() - sub, err := relay.Subscribe(ctx, "sub1", Filter{Kinds: []int{1}}) - if err != nil { - t.Fatalf("Subscribe() error = %v", err) - } + sub := relay.Subscribe(ctx, Filter{Kinds: []int{1}}) if err := sub.Close(ctx); err != nil { t.Errorf("Subscription.Close() error = %v", err) -- cgit v1.2.3