package websocket

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"northwest.io/muxstr/internal/ratelimit"
	"northwest.io/muxstr/internal/storage"
	"northwest.io/muxstr/internal/subscription"
	ws "northwest.io/muxstr/internal/websocket"

	pb "northwest.io/muxstr/api/nostr/v1"

	"fiatjaf.com/nostr"
)

// mockAuthStore is an in-memory store used by the tests. It tracks outstanding
// AUTH challenges and provides no-op StoreEvent, QueryEvents and
// ProcessDeletion implementations.
type mockAuthStore struct {
	mu         sync.Mutex
	challenges map[string]time.Time // challenge -> creation time; guarded by mu
	seq        uint64               // monotonically increasing suffix; guarded by mu
}

// newMockAuthStore returns a mockAuthStore with an initialized challenge map.
func newMockAuthStore() *mockAuthStore {
	return &mockAuthStore{
		challenges: make(map[string]time.Time),
	}
}

// CreateAuthChallenge records and returns a fresh challenge string.
//
// The timestamp alone is not guaranteed to be unique on platforms with a
// coarse clock, so a per-store sequence number is appended.
func (m *mockAuthStore) CreateAuthChallenge(ctx context.Context) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	now := time.Now()
	m.seq++
	challenge := fmt.Sprintf("test-challenge-%d-%d", now.UnixNano(), m.seq)
	m.challenges[challenge] = now
	return challenge, nil
}

// ValidateAndConsumeChallenge succeeds exactly once per issued challenge: the
// challenge is removed on success, and unknown challenges yield an error.
func (m *mockAuthStore) ValidateAndConsumeChallenge(ctx context.Context, challenge string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if _, exists := m.challenges[challenge]; !exists {
		return fmt.Errorf("invalid challenge")
	}
	delete(m.challenges, challenge)
	return nil
}

// StoreEvent accepts every event and stores nothing.
func (m *mockAuthStore) StoreEvent(ctx context.Context, event *storage.EventData) error {
	return nil
}

// QueryEvents always returns an empty, non-nil result set.
func (m *mockAuthStore) QueryEvents(ctx context.Context, filters []*pb.Filter, opts *storage.QueryOptions) ([]*pb.Event, error) {
	return []*pb.Event{}, nil
}

// ProcessDeletion accepts every deletion request and does nothing.
func (m *mockAuthStore) ProcessDeletion(ctx context.Context, event *pb.Event) error {
	return nil
}

// mockMetrics records the metrics calls the tests assert on (request counts
// keyed by "method:status", blocked events by kind, and a connection counter).
// The remaining metric methods are no-ops.
type mockMetrics struct {
	mu            sync.Mutex
	connections   int
	requests      map[string]int
	blockedEvents map[int32]int
}

// newMockMetrics returns a mockMetrics with its maps initialized.
func newMockMetrics() *mockMetrics {
	return &mockMetrics{
		requests:      make(map[string]int),
		blockedEvents: make(map[int32]int),
	}
}

// IncrementConnections bumps the connection counter.
func (m *mockMetrics) IncrementConnections() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.connections++
}

// DecrementConnections lowers the connection counter.
func (m *mockMetrics) DecrementConnections() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.connections--
}

// IncrementSubscriptions is a no-op.
func (m *mockMetrics) IncrementSubscriptions() {}

// DecrementSubscriptions is a no-op.
func (m *mockMetrics) DecrementSubscriptions() {}

// SetActiveSubscriptions is a no-op.
func (m *mockMetrics) SetActiveSubscriptions(count int) {}

// RecordRequest counts one request under the key "method:status"; duration is
// ignored.
func (m *mockMetrics) RecordRequest(method, status string, duration float64) {
	m.mu.Lock()
	defer m.mu.Unlock()
	key := fmt.Sprintf("%s:%s", method, status)
	m.requests[key]++
}

// RecordAuthAttempt is a no-op.
func (m *mockMetrics) RecordAuthAttempt(success bool) {}

// RecordBlockedEvent counts one blocked event of the given kind.
func (m *mockMetrics) RecordBlockedEvent(kind int32) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.blockedEvents[kind]++
}

// getRequestCount returns how many times RecordRequest was called with the
// given method and status.
func (m *mockMetrics) getRequestCount(method, status string) int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.requests[fmt.Sprintf("%s:%s", method, status)]
}

// testServer bundles an httptest server running the handler together with the
// mocks wired into it, so tests can inspect them afterwards.
type testServer struct {
	server  *httptest.Server
	handler *Handler
	store   *mockAuthStore
	metrics *mockMetrics
	limiter *ratelimit.Limiter
}

// newTestServer starts an httptest server around a new Handler.
//
// When authConfig is non-nil the mock store is installed as the auth backend
// and the config is applied. When enableRateLimit is true a limiter with 2
// requests/second and a burst of 2 is attached; otherwise limiter stays nil.
func newTestServer(authConfig *AuthConfig, enableRateLimit bool) *testServer {
	store := newMockAuthStore()
	metrics := newMockMetrics()
	subs := subscription.NewManager()

	handler := NewHandler(store, subs)
	handler.SetMetrics(metrics)

	if authConfig != nil {
		handler.SetAuth(store)
		handler.SetAuthConfig(authConfig)
	}

	var limiter *ratelimit.Limiter
	if enableRateLimit {
		limiter = ratelimit.New(&ratelimit.Config{
			RequestsPerSecond: 2, // Low limit for easy testing
			BurstSize:         2,
		})
		handler.SetRateLimiter(limiter)
	}

	server := httptest.NewServer(handler)

	return &testServer{
		server:  server,
		handler: handler,
		store:   store,
		metrics: metrics,
		limiter: limiter,
	}
}

// Close shuts down the underlying httptest server.
func (ts *testServer) Close() {
	ts.server.Close()
}

// wsURL returns the server URL with its "http" prefix replaced by "ws".
func (ts *testServer) wsURL() string {
	return "ws" + strings.TrimPrefix(ts.server.URL, "http")
}

// connectWS dials the test server and fails the test if the connection cannot
// be established.
func (ts *testServer) connectWS(t *testing.T) *ws.Conn {
	t.Helper()
	conn, err := ws.Dial(context.Background(), ts.wsURL())
	if err != nil {
		t.Fatalf("Failed to connect: %v", err)
	}
	return conn
}

// sendMessage JSON-encodes msg and writes it as a text frame, failing the test
// on any marshal or write error.
func sendMessage(t *testing.T, conn *ws.Conn, msg interface{}) {
	t.Helper()
	data, err := json.Marshal(msg)
	if err != nil {
		t.Fatalf("Failed to marshal message: %v", err)
	}
	if err := conn.Write(context.Background(), ws.MessageText, data); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}
}

// sendEvent sends ["EVENT", event] and returns the next message read from the
// connection.
func sendEvent(t *testing.T, conn *ws.Conn, event *nostr.Event) []interface{} {
	t.Helper()
	sendMessage(t, conn, []interface{}{"EVENT", event})
	return readMessage(t, conn)
}

// sendAuth sends ["AUTH", authEvent]. It does not read a response.
func sendAuth(t *testing.T, conn *ws.Conn, authEvent *nostr.Event) {
	t.Helper()
	sendMessage(t, conn, []interface{}{"AUTH", authEvent})
}

// readMessage reads and decodes the next JSON array message, failing the test
// on error. The read is bounded so that a handler that never answers produces a
// prompt, attributable failure instead of hanging until the global test
// timeout.
func readMessage(t *testing.T, conn *ws.Conn) []interface{} {
	t.Helper()
	const timeout = 5 * time.Second
	msg, err := readMessageWithTimeout(conn, timeout)
	if err != nil {
		t.Fatalf("Failed to read message: %v", err)
	}
	return msg
}

// readMessageWithTimeout reads the next message, waiting at most timeout, and
// decodes it as a JSON array. Read and decode errors are returned wrapped.
func readMessageWithTimeout(conn *ws.Conn, timeout time.Duration) ([]interface{}, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	_, data, err := conn.Read(ctx)
	if err != nil {
		return nil, fmt.Errorf("reading message: %w", err)
	}

	var msg []interface{}
	if err := json.Unmarshal(data, &msg); err != nil {
		return nil, fmt.Errorf("decoding message %q: %w", data, err)
	}
	return msg, nil
}

// mustAuthChallenge asserts that msg is an ["AUTH", <challenge>] frame and
// returns the challenge string, failing the test otherwise.
func mustAuthChallenge(t *testing.T, msg []interface{}) string {
	t.Helper()
	if len(msg) < 2 || msg[0] != "AUTH" {
		t.Fatalf("Expected AUTH challenge, got: %v", msg)
	}
	challenge, ok := msg[1].(string)
	if !ok {
		t.Fatalf("Expected AUTH challenge to be a string, got: %v", msg[1])
	}
	return challenge
}

// mustOKResult asserts that msg is an ["OK", <id>, <accepted>, ...] frame and
// returns the accepted flag together with the reason text. The reason is ""
// when the frame carries only three elements. Malformed frames fail the test
// instead of panicking on an index or type assertion.
func mustOKResult(t *testing.T, msg []interface{}) (bool, string) {
	t.Helper()
	if len(msg) < 3 || msg[0] != "OK" {
		t.Fatalf("Expected OK message, got: %v", msg)
	}
	accepted, ok := msg[2].(bool)
	if !ok {
		t.Fatalf("Expected OK status to be a bool, got: %v", msg)
	}
	reason := ""
	if len(msg) > 3 {
		if reason, ok = msg[3].(string); !ok {
			t.Fatalf("Expected OK reason to be a string, got: %v", msg)
		}
	}
	return accepted, reason
}

// createTestEvent creates a signed text-note event authored by sk.
//
// It panics if signing fails: the function has no *testing.T, and an unsigned
// event would only surface later as a misleading rejection.
func createTestEvent(sk nostr.SecretKey, content string) *nostr.Event {
	event := &nostr.Event{
		PubKey:    nostr.GetPublicKey(sk),
		CreatedAt: nostr.Now(),
		Kind:      nostr.KindTextNote,
		Tags:      nostr.Tags{},
		Content:   content,
	}
	if err := event.Sign(sk); err != nil {
		panic(fmt.Sprintf("signing test event: %v", err))
	}
	return event
}

// createTestAuthEvent builds and signs a kind 22242 auth event carrying the
// given relay URL and challenge as tags, failing the test if signing fails.
func createTestAuthEvent(t *testing.T, sk nostr.SecretKey, relayURL, challenge string) *nostr.Event {
	t.Helper()
	authEvent := &nostr.Event{
		PubKey:    nostr.GetPublicKey(sk),
		CreatedAt: nostr.Now(),
		Kind:      22242,
		Tags: nostr.Tags{
			{"relay", relayURL},
			{"challenge", challenge},
		},
		Content: "",
	}
	if err := authEvent.Sign(sk); err != nil {
		t.Fatalf("Failed to sign auth event: %v", err)
	}
	return authEvent
}

// TestAuthRequired verifies that AUTH is required when configured
func TestAuthRequired(t *testing.T) {
	sk := nostr.Generate()
	pubkey := nostr.GetPublicKey(sk)

	authConfig := &AuthConfig{
		WriteEnabled:        true,
		WriteAllowedPubkeys: []string{fmt.Sprintf("%x", pubkey[:])},
	}

	ts := newTestServer(authConfig, false)
	defer ts.Close()

	conn := ts.connectWS(t)
	defer conn.Close(ws.StatusNormalClosure, "test done")

	// Try to publish without auth
	event := createTestEvent(sk, "test without auth")
	sendMessage(t, conn, []interface{}{"EVENT", event})

	// Should receive AUTH challenge
	challenge := mustAuthChallenge(t, readMessage(t, conn))
	t.Logf("Received AUTH challenge: %s", challenge)

	// Should also receive OK false
	accepted, reason := mustOKResult(t, readMessage(t, conn))
	if accepted {
		t.Errorf("Expected OK false, got true")
	}
	if !strings.Contains(reason, "auth-required") {
		t.Errorf("Expected 'auth-required' message, got: %s", reason)
	}
	t.Logf("Received OK false: %v", reason)

	// Now authenticate.
	// NOTE(review): the relay tag carries the http:// server URL rather than
	// wsURL(); confirm this matches what the handler validates.
	sendAuth(t, conn, createTestAuthEvent(t, sk, ts.server.URL, challenge))

	// Retry the EVENT
	event2 := createTestEvent(sk, "test with auth")
	accepted, reason = mustOKResult(t, sendEvent(t, conn, event2))

	// Should now succeed
	if !accepted {
		t.Errorf("Expected OK true after auth, got false: %v", reason)
	}
	t.Logf("Publish succeeded after auth")

	// Verify authorized requests are tracked in metrics
	authorizedCount := ts.metrics.getRequestCount("EVENT", "authorized")
	if authorizedCount == 0 {
		t.Errorf("Expected authorized requests to be tracked in metrics, got 0")
	}
	t.Logf("Metrics: %d authorized requests tracked", authorizedCount)
}

// TestAuthNotInAllowlist verifies that pubkeys not in allowlist are rejected
func TestAuthNotInAllowlist(t *testing.T) {
	allowedSk := nostr.Generate()
	allowedPubkey := nostr.GetPublicKey(allowedSk)

	unauthorizedSk := nostr.Generate()

	authConfig := &AuthConfig{
		WriteEnabled:        true,
		WriteAllowedPubkeys: []string{fmt.Sprintf("%x", allowedPubkey[:])},
	}

	ts := newTestServer(authConfig, false)
	defer ts.Close()

	conn := ts.connectWS(t)
	defer conn.Close(ws.StatusNormalClosure, "test done")

	event := createTestEvent(unauthorizedSk, "unauthorized test")
	sendMessage(t, conn, []interface{}{"EVENT", event})

	// Receive AUTH challenge
	challenge := mustAuthChallenge(t, readMessage(t, conn))

	// Receive OK false (auth required)
	msg2 := readMessage(t, conn)
	if accepted, _ := mustOKResult(t, msg2); accepted {
		t.Fatalf("Expected OK false, got: %v", msg2)
	}

	// Authenticate with unauthorized key
	sendAuth(t, conn, createTestAuthEvent(t, unauthorizedSk, ts.server.URL, challenge))

	// Retry EVENT with unauthorized key
	event2 := createTestEvent(unauthorizedSk, "retry unauthorized")
	accepted, reason := mustOKResult(t, sendEvent(t, conn, event2))

	// Should be rejected - not in allowlist
	if accepted {
		t.Errorf("Expected OK false for unauthorized pubkey, got true: %v", reason)
	} else {
		t.Logf("Unauthorized pubkey correctly rejected: %v", reason)
	}

	// Verify metrics tracked the unauthorized request
	unauthorizedCount := ts.metrics.getRequestCount("EVENT", "unauthorized")
	if unauthorizedCount == 0 {
		t.Errorf("Expected unauthorized requests to be tracked in metrics, got 0")
	}
	t.Logf("Metrics: %d unauthorized requests tracked", unauthorizedCount)
}

// TestRateLimitByIP verifies that rate limiting works by IP
func TestRateLimitByIP(t *testing.T) {
	ts := newTestServer(nil, true) // No auth, but rate limiting enabled
	defer ts.Close()

	conn := ts.connectWS(t)
	defer conn.Close(ws.StatusNormalClosure, "test done")

	sk := nostr.Generate()

	// Rate limit is 2 req/sec with burst 2
	// So 3rd request should be blocked
	successCount := 0
	rateLimitCount := 0

	for i := 0; i < 5; i++ {
		event := createTestEvent(sk, fmt.Sprintf("test event %d", i))
		accepted, reason := mustOKResult(t, sendEvent(t, conn, event))

		if accepted {
			successCount++
			t.Logf("Event %d: accepted", i)
		} else {
			rateLimitCount++
			if !strings.Contains(reason, "rate-limited") {
				t.Errorf("Expected 'rate-limited' message, got: %v", reason)
			}
			t.Logf("Event %d: rate limited - %v", i, reason)
		}

		time.Sleep(10 * time.Millisecond)
	}

	if successCount < 2 {
		t.Errorf("Expected at least 2 successful requests (burst), got %d", successCount)
	}
	if rateLimitCount == 0 {
		t.Errorf("Expected some requests to be rate limited, got 0")
	}
	t.Logf("Rate limiting working: %d accepted, %d rate limited", successCount, rateLimitCount)
}

// TestRateLimitByPubkey verifies that rate limiting works by authenticated pubkey
func TestRateLimitByPubkey(t *testing.T) {
	sk := nostr.Generate()
	pubkey := nostr.GetPublicKey(sk)

	authConfig := &AuthConfig{
		WriteEnabled:        true,
		WriteAllowedPubkeys: []string{fmt.Sprintf("%x", pubkey[:])},
	}

	ts := newTestServer(authConfig, true) // Auth + rate limiting
	defer ts.Close()

	conn := ts.connectWS(t)
	defer conn.Close(ws.StatusNormalClosure, "test done")

	// Authenticate first
	event := createTestEvent(sk, "trigger auth")
	sendMessage(t, conn, []interface{}{"EVENT", event})

	// Get AUTH challenge
	challenge := mustAuthChallenge(t, readMessage(t, conn))

	// Read OK false
	readMessage(t, conn)

	// Send AUTH
	sendAuth(t, conn, createTestAuthEvent(t, sk, ts.server.URL, challenge))

	// Now spam events - should be rate limited by pubkey
	successCount := 0
	rateLimitCount := 0

	for i := 0; i < 5; i++ {
		event := createTestEvent(sk, fmt.Sprintf("spam %d", i))
		accepted, reason := mustOKResult(t, sendEvent(t, conn, event))

		if accepted {
			successCount++
			t.Logf("Event %d: accepted", i)
		} else {
			rateLimitCount++
			// A rejection for any other reason (for example a failed AUTH)
			// must not be counted as evidence that rate limiting works.
			if !strings.Contains(reason, "rate-limited") {
				t.Errorf("Expected 'rate-limited' message, got: %v", reason)
			}
			t.Logf("Event %d: rate limited - %v", i, reason)
		}

		time.Sleep(10 * time.Millisecond)
	}

	if rateLimitCount == 0 {
		t.Errorf("Expected rate limiting by pubkey, but all requests succeeded")
	}
	t.Logf("Rate limiting by pubkey working: %d accepted, %d rate limited", successCount, rateLimitCount)
}

// TestNoAuthWhenDisabled verifies that publishing works without auth when auth is disabled
func TestNoAuthWhenDisabled(t *testing.T) {
	ts := newTestServer(nil, false) // No auth, no rate limiting
	defer ts.Close()

	conn := ts.connectWS(t)
	defer conn.Close(ws.StatusNormalClosure, "test done")

	sk := nostr.Generate()
	event := createTestEvent(sk, "test without auth required")

	accepted, reason := mustOKResult(t, sendEvent(t, conn, event))
	if !accepted {
		t.Errorf("Expected OK true when auth disabled, got false: %v", reason)
	}
	t.Logf("Publishing without auth succeeded as expected")
}