From d744c32f1bc7411e04c97a9d14c172baaa0e4a89 Mon Sep 17 00:00:00 2001 From: bndw Date: Sun, 15 Feb 2026 10:06:18 -0800 Subject: test: add integration tests for NIP-42 AUTH and rate limiting Add comprehensive WebSocket handler integration tests that verify: - NIP-42 authentication flow (auth required, challenge/response) - Allowlist enforcement (reject unauthorized pubkeys) - Rate limiting by IP address - Rate limiting by authenticated pubkey - No-auth mode works correctly These tests use real WebSocket connections and would have caught the AUTH timeout bug and other protocol issues. Tests cover: - TestAuthRequired: Verifies AUTH challenge sent, client authenticates, publish succeeds - TestAuthNotInAllowlist: Verifies pubkeys not in allowlist are rejected - TestRateLimitByIP: Verifies unauthenticated clients are rate limited by IP - TestRateLimitByPubkey: Verifies authenticated clients are rate limited by pubkey - TestNoAuthWhenDisabled: Verifies publishing works when auth is disabled --- internal/handler/websocket/handler_test.go | 526 +++++++++++++++++++++++++++++ 1 file changed, 526 insertions(+) create mode 100644 internal/handler/websocket/handler_test.go (limited to 'internal/handler/websocket/handler_test.go') diff --git a/internal/handler/websocket/handler_test.go b/internal/handler/websocket/handler_test.go new file mode 100644 index 0000000..9f02510 --- /dev/null +++ b/internal/handler/websocket/handler_test.go @@ -0,0 +1,526 @@ +package websocket + +import ( + "context" + "encoding/json" + "fmt" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "northwest.io/muxstr/internal/ratelimit" + "northwest.io/muxstr/internal/storage" + "northwest.io/muxstr/internal/subscription" + ws "northwest.io/muxstr/internal/websocket" + pb "northwest.io/muxstr/api/nostr/v1" + + "fiatjaf.com/nostr" +) + +// mockAuthStore implements the auth methods needed for testing +type mockAuthStore struct { + challenges map[string]time.Time + mu sync.Mutex +} + +func newMockAuthStore() *mockAuthStore { + return &mockAuthStore{ + challenges: make(map[string]time.Time), + } +} + +func (m *mockAuthStore) CreateAuthChallenge(ctx context.Context) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + challenge := fmt.Sprintf("test-challenge-%d", time.Now().UnixNano()) + m.challenges[challenge] = time.Now() + return challenge, nil +} + +func (m *mockAuthStore) ValidateAndConsumeChallenge(ctx context.Context, challenge string) error { + m.mu.Lock() + defer m.mu.Unlock() + if _, exists := m.challenges[challenge]; !exists { + return fmt.Errorf("invalid challenge") + } + delete(m.challenges, challenge) + return nil +} + +func (m *mockAuthStore) StoreEvent(ctx context.Context, event *storage.EventData) error { + return nil +} + +func (m *mockAuthStore) QueryEvents(ctx context.Context, filters []*pb.Filter, opts *storage.QueryOptions) ([]*pb.Event, error) { + return []*pb.Event{}, nil +} + +func (m *mockAuthStore) ProcessDeletion(ctx context.Context, event *pb.Event) error { + return nil +} + +// mockMetrics implements metrics recording for testing +type mockMetrics struct { + mu sync.Mutex + connections int + requests map[string]int + blockedEvents map[int32]int +} + +func newMockMetrics() *mockMetrics { + return &mockMetrics{ + requests: make(map[string]int), + blockedEvents: make(map[int32]int), + } +} + +func (m *mockMetrics) IncrementConnections() { + m.mu.Lock() + defer m.mu.Unlock() + m.connections++ +} + +func (m *mockMetrics) DecrementConnections() { + m.mu.Lock() + defer m.mu.Unlock() + m.connections-- +} + +func (m *mockMetrics) IncrementSubscriptions() {} +func (m *mockMetrics) DecrementSubscriptions() {} +func (m *mockMetrics) SetActiveSubscriptions(count int) {} +func (m *mockMetrics) RecordRequest(method, status string, duration float64) { + m.mu.Lock() + defer m.mu.Unlock() + key := fmt.Sprintf("%s:%s", method, status) + m.requests[key]++ +} +func (m *mockMetrics) RecordBlockedEvent(kind int32) { + m.mu.Lock() + defer m.mu.Unlock() + m.blockedEvents[kind]++ +} + +func (m *mockMetrics) getRequestCount(method, status string) int { + m.mu.Lock() + defer m.mu.Unlock() + return m.requests[fmt.Sprintf("%s:%s", method, status)] +} + +// testServer sets up a test WebSocket server with the handler +type testServer struct { + server *httptest.Server + handler *Handler + store *mockAuthStore + metrics *mockMetrics + limiter *ratelimit.Limiter +} + +func newTestServer(authConfig *AuthConfig, enableRateLimit bool) *testServer { + store := newMockAuthStore() + metrics := newMockMetrics() + subs := subscription.NewManager() + + handler := NewHandler(store, subs) + handler.SetMetrics(metrics) + + if authConfig != nil { + handler.SetAuth(store) + handler.SetAuthConfig(authConfig) + } + + var limiter *ratelimit.Limiter + if enableRateLimit { + limiter = ratelimit.New(&ratelimit.Config{ + RequestsPerSecond: 2, // Low limit for easy testing + BurstSize: 2, + }) + handler.SetRateLimiter(limiter) + } + + server := httptest.NewServer(handler) + + return &testServer{ + server: server, + handler: handler, + store: store, + metrics: metrics, + limiter: limiter, + } +} + +func (ts *testServer) Close() { + ts.server.Close() +} + +func (ts *testServer) wsURL() string { + return "ws" + strings.TrimPrefix(ts.server.URL, "http") +} + +// connectWS creates a WebSocket connection to the test server +func (ts *testServer) connectWS(t *testing.T) *ws.Conn { + conn, err := ws.Dial(context.Background(), ts.wsURL()) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + return conn +} + +// sendMessage sends a JSON message over WebSocket +func sendMessage(t *testing.T, conn *ws.Conn, msg interface{}) { + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Failed to marshal message: %v", err) + } + if err := conn.Write(context.Background(), ws.MessageText, data); err != nil { + t.Fatalf("Failed to send message: %v", err) + } +} + +// sendEvent sends an EVENT message and returns the response +func sendEvent(t *testing.T, conn *ws.Conn, event *nostr.Event) []interface{} { + msg := []interface{}{"EVENT", event} + sendMessage(t, conn, msg) + return readMessage(t, conn) +} + +// sendAuth sends an AUTH message +func sendAuth(t *testing.T, conn *ws.Conn, authEvent *nostr.Event) { + msg := []interface{}{"AUTH", authEvent} + sendMessage(t, conn, msg) +} + +// readMessage reads a message from the WebSocket +func readMessage(t *testing.T, conn *ws.Conn) []interface{} { + _, data, err := conn.Read(context.Background()) + if err != nil { + t.Fatalf("Failed to read message: %v", err) + } + var msg []interface{} + if err := json.Unmarshal(data, &msg); err != nil { + t.Fatalf("Failed to unmarshal message: %v", err) + } + return msg +} + +// readMessageWithTimeout reads a message with timeout +func readMessageWithTimeout(conn *ws.Conn, timeout time.Duration) ([]interface{}, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + _, data, err := conn.Read(ctx) + if err != nil { + return nil, err + } + + var msg []interface{} + if err := json.Unmarshal(data, &msg); err != nil { + return nil, err + } + return msg, nil +} + +// createTestEvent creates a signed test event +func createTestEvent(sk nostr.SecretKey, content string) *nostr.Event { + event := &nostr.Event{ + PubKey: nostr.GetPublicKey(sk), + CreatedAt: nostr.Now(), + Kind: nostr.KindTextNote, + Tags: nostr.Tags{}, + Content: content, + } + event.Sign(sk) + return event +} + +// TestAuthRequired verifies that AUTH is required when configured +func TestAuthRequired(t *testing.T) { + sk := nostr.Generate() + pubkey := nostr.GetPublicKey(sk) + + authConfig := &AuthConfig{ + WriteEnabled: true, + WriteAllowedPubkeys: []string{fmt.Sprintf("%x", pubkey[:])}, + } + + ts := newTestServer(authConfig, false) + defer ts.Close() + + conn := ts.connectWS(t) + defer conn.Close(ws.StatusNormalClosure, "test done") + + // Try to publish without auth + event := createTestEvent(sk, "test without auth") + + // Send EVENT + sendMessage(t, conn, []interface{}{"EVENT", event}) + + // Should receive AUTH challenge + msg1 := readMessage(t, conn) + if len(msg1) < 2 || msg1[0] != "AUTH" { + t.Fatalf("Expected AUTH challenge, got: %v", msg1) + } + challenge := msg1[1].(string) + t.Logf("Received AUTH challenge: %s", challenge) + + // Should also receive OK false + msg2 := readMessage(t, conn) + if len(msg2) < 4 || msg2[0] != "OK" { + t.Fatalf("Expected OK message, got: %v", msg2) + } + if msg2[2].(bool) != false { + t.Errorf("Expected OK false, got true") + } + if !strings.Contains(msg2[3].(string), "auth-required") { + t.Errorf("Expected 'auth-required' message, got: %s", msg2[3]) + } + t.Logf("Received OK false: %v", msg2[3]) + + // Now authenticate + authEvent := &nostr.Event{ + PubKey: pubkey, + CreatedAt: nostr.Now(), + Kind: 22242, + Tags: nostr.Tags{ + {"relay", ts.server.URL}, + {"challenge", challenge}, + }, + Content: "", + } + authEvent.Sign(sk) + sendAuth(t, conn, authEvent) + + // Retry the EVENT + event2 := createTestEvent(sk, "test with auth") + msg3 := sendEvent(t, conn, event2) + + // Should now succeed + if len(msg3) < 4 || msg3[0] != "OK" { + t.Fatalf("Expected OK message, got: %v", msg3) + } + if msg3[2].(bool) != true { + t.Errorf("Expected OK true after auth, got false: %v", msg3[3]) + } + t.Logf("Publish succeeded after auth") +} + +// TestAuthNotInAllowlist verifies that pubkeys not in allowlist are rejected +func TestAuthNotInAllowlist(t *testing.T) { + allowedSk := nostr.Generate() + allowedPubkey := nostr.GetPublicKey(allowedSk) + + unauthorizedSk := nostr.Generate() + + authConfig := &AuthConfig{ + WriteEnabled: true, + WriteAllowedPubkeys: []string{fmt.Sprintf("%x", allowedPubkey[:])}, + } + + ts := newTestServer(authConfig, false) + defer ts.Close() + + conn := ts.connectWS(t) + defer conn.Close(ws.StatusNormalClosure, "test done") + + event := createTestEvent(unauthorizedSk, "unauthorized test") + + // Send EVENT + sendMessage(t, conn, []interface{}{"EVENT", event}) + + // Receive AUTH challenge + msg1 := readMessage(t, conn) + if len(msg1) < 2 || msg1[0] != "AUTH" { + t.Fatalf("Expected AUTH challenge, got: %v", msg1) + } + challenge := msg1[1].(string) + + // Receive OK false (auth required) + msg2 := readMessage(t, conn) + if msg2[0] != "OK" || msg2[2].(bool) != false { + t.Fatalf("Expected OK false, got: %v", msg2) + } + + // Authenticate with unauthorized key + authEvent := &nostr.Event{ + PubKey: nostr.GetPublicKey(unauthorizedSk), + CreatedAt: nostr.Now(), + Kind: 22242, + Tags: nostr.Tags{ + {"relay", ts.server.URL}, + {"challenge", challenge}, + }, + Content: "", + } + authEvent.Sign(unauthorizedSk) + sendAuth(t, conn, authEvent) + + // Retry EVENT with unauthorized key + event2 := createTestEvent(unauthorizedSk, "retry unauthorized") + msg3 := sendEvent(t, conn, event2) + + // Should be rejected - not in allowlist + if len(msg3) < 3 || msg3[0] != "OK" { + t.Fatalf("Expected OK message, got: %v", msg3) + } + if msg3[2].(bool) != false { + t.Errorf("Expected OK false for unauthorized pubkey, got false: %v", msg3[3]) + } + t.Logf("Unauthorized pubkey correctly rejected: %v", msg3[3]) +} + +// TestRateLimitByIP verifies that rate limiting works by IP +func TestRateLimitByIP(t *testing.T) { + ts := newTestServer(nil, true) // No auth, but rate limiting enabled + defer ts.Close() + + conn := ts.connectWS(t) + defer conn.Close(ws.StatusNormalClosure, "test done") + + sk := nostr.Generate() + + // Rate limit is 2 req/sec with burst 2 + // So 3rd request should be blocked + + successCount := 0 + rateLimitCount := 0 + + for i := 0; i < 5; i++ { + event := createTestEvent(sk, fmt.Sprintf("test event %d", i)) + msg := sendEvent(t, conn, event) + + if len(msg) < 3 || msg[0] != "OK" { + t.Fatalf("Expected OK message, got: %v", msg) + } + + if msg[2].(bool) { + successCount++ + t.Logf("Event %d: accepted", i) + } else { + rateLimitCount++ + msgStr := "" + if len(msg) > 3 { + msgStr = msg[3].(string) + } + if !strings.Contains(msgStr, "rate-limited") { + t.Errorf("Expected 'rate-limited' message, got: %v", msgStr) + } + t.Logf("Event %d: rate limited - %v", i, msgStr) + } + + time.Sleep(10 * time.Millisecond) + } + + if successCount < 2 { + t.Errorf("Expected at least 2 successful requests (burst), got %d", successCount) + } + + if rateLimitCount == 0 { + t.Errorf("Expected some requests to be rate limited, got 0") + } + + t.Logf("Rate limiting working: %d accepted, %d rate limited", successCount, rateLimitCount) +} + +// TestRateLimitByPubkey verifies that rate limiting works by authenticated pubkey +func TestRateLimitByPubkey(t *testing.T) { + sk := nostr.Generate() + pubkey := nostr.GetPublicKey(sk) + + authConfig := &AuthConfig{ + WriteEnabled: true, + WriteAllowedPubkeys: []string{fmt.Sprintf("%x", pubkey[:])}, + } + + ts := newTestServer(authConfig, true) // Auth + rate limiting + defer ts.Close() + + conn := ts.connectWS(t) + defer conn.Close(ws.StatusNormalClosure, "test done") + + // Authenticate first + event := createTestEvent(sk, "trigger auth") + sendMessage(t, conn, []interface{}{"EVENT", event}) + + // Get AUTH challenge + msg1 := readMessage(t, conn) + if msg1[0] != "AUTH" { + t.Fatalf("Expected AUTH, got: %v", msg1) + } + challenge := msg1[1].(string) + + // Read OK false + readMessage(t, conn) + + // Send AUTH + authEvent := &nostr.Event{ + PubKey: pubkey, + CreatedAt: nostr.Now(), + Kind: 22242, + Tags: nostr.Tags{ + {"relay", ts.server.URL}, + {"challenge", challenge}, + }, + Content: "", + } + authEvent.Sign(sk) + sendAuth(t, conn, authEvent) + + // Now spam events - should be rate limited by pubkey + successCount := 0 + rateLimitCount := 0 + + for i := 0; i < 5; i++ { + event := createTestEvent(sk, fmt.Sprintf("spam %d", i)) + msg := sendEvent(t, conn, event) + + if len(msg) < 3 || msg[0] != "OK" { + t.Fatalf("Expected OK, got: %v", msg) + } + + if msg[2].(bool) { + successCount++ + t.Logf("Event %d: accepted", i) + } else { + rateLimitCount++ + msgStr := "" + if len(msg) > 3 { + msgStr = msg[3].(string) + } + t.Logf("Event %d: rate limited - %v", i, msgStr) + } + + time.Sleep(10 * time.Millisecond) + } + + if rateLimitCount == 0 { + t.Errorf("Expected rate limiting by pubkey, but all requests succeeded") + } + + t.Logf("Rate limiting by pubkey working: %d accepted, %d rate limited", successCount, rateLimitCount) +} + +// TestNoAuthWhenDisabled verifies that publishing works without auth when auth is disabled +func TestNoAuthWhenDisabled(t *testing.T) { + ts := newTestServer(nil, false) // No auth, no rate limiting + defer ts.Close() + + conn := ts.connectWS(t) + defer conn.Close(ws.StatusNormalClosure, "test done") + + sk := nostr.Generate() + event := createTestEvent(sk, "test without auth required") + + msg := sendEvent(t, conn, event) + + if len(msg) < 3 || msg[0] != "OK" { + t.Fatalf("Expected OK message, got: %v", msg) + } + + if msg[2].(bool) != true { + t.Errorf("Expected OK true when auth disabled, got false: %v", msg[3]) + } + + t.Logf("Publishing without auth succeeded as expected") +} -- cgit v1.2.3