package auth

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"testing"
	"time"

	"google.golang.org/grpc/metadata"
	"northwest.io/muxstr/internal/nostr"
)

// TestNostrCredentials checks that the per-RPC credentials built from a
// freshly generated key emit an "authorization" metadata entry that parses
// into a signed kind-27235 event carrying the request URI and the POST method.
func TestNostrCredentials(t *testing.T) {
	key, err := nostr.GenerateKey()
	if err != nil {
		t.Fatalf("failed to generate key: %v", err)
	}

	creds := NewNostrCredentials(key)

	// Test GetRequestMetadata
	ctx := context.Background()
	uri := "https://example.com/nostr.v1.NostrRelay/PublishEvent"

	md, err := creds.GetRequestMetadata(ctx, uri)
	if err != nil {
		t.Fatalf("GetRequestMetadata failed: %v", err)
	}

	// Check authorization header exists
	authHeader, ok := md["authorization"]
	if !ok {
		t.Fatal("missing authorization header")
	}

	// Parse and validate the event
	event, err := ParseAuthHeader(authHeader)
	if err != nil {
		t.Fatalf("failed to parse auth header: %v", err)
	}

	if event.Kind != 27235 {
		t.Errorf("wrong event kind: got %d, want 27235", event.Kind)
	}

	if event.PubKey != key.Public() {
		t.Error("pubkey mismatch")
	}

	if !event.Verify() {
		t.Error("event signature verification failed")
	}

	// Check tags
	uTag := event.Tags.Find("u")
	if uTag == nil {
		t.Fatal("missing 'u' tag")
	}
	if uTag.Value() != uri {
		t.Errorf("wrong URI in tag: got %s, want %s", uTag.Value(), uri)
	}

	methodTag := event.Tags.Find("method")
	if methodTag == nil {
		t.Fatal("missing 'method' tag")
	}
	if methodTag.Value() != "POST" {
		t.Errorf("wrong method in tag: got %s, want POST", methodTag.Value())
	}
}

// TestParseAuthHeader feeds malformed authorization headers to
// ParseAuthHeader and expects each of them to be rejected with an error.
func TestParseAuthHeader(t *testing.T) {
	tests := []struct {
		name    string
		header  string
		wantErr bool
	}{
		{
			name:    "empty header",
			header:  "",
			wantErr: true,
		},
		{
			name:    "missing prefix",
			header:  "Bearer token",
			wantErr: true,
		},
		{
			name:    "invalid base64",
			header:  "Nostr not-base64!",
			wantErr: true,
		},
		{
			name:    "invalid json",
			header:  "Nostr " + base64.StdEncoding.EncodeToString([]byte("not json")),
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := ParseAuthHeader(tt.header)
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseAuthHeader() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestValidateAuthEvent runs ValidateAuthEvent against one correctly signed
// event and several events/options that must be rejected (wrong kind, stale
// timestamp, URI mismatch, method mismatch).
func TestValidateAuthEvent(t *testing.T) {
	// A failed key generation would otherwise show up later as a confusing
	// validation failure, so abort the test immediately.
	key, err := nostr.GenerateKey()
	if err != nil {
		t.Fatalf("failed to generate key: %v", err)
	}

	// Create a valid event
	event := &nostr.Event{
		PubKey:    key.Public(),
		CreatedAt: time.Now().Unix(),
		Kind:      27235,
		Tags: nostr.Tags{
			{"u", "https://example.com/test"},
			{"method", "POST"},
		},
		Content: "",
	}
	// NOTE(review): if Sign returns an error it should be checked here —
	// confirm against the nostr package.
	key.Sign(event)

	tests := []struct {
		name    string
		event   *nostr.Event
		opts    ValidationOptions
		wantErr bool
	}{
		{
			name:  "valid event",
			event: event,
			opts: ValidationOptions{
				TimestampWindow: 60,
				ExpectedURI:     "https://example.com/test",
				ExpectedMethod:  "POST",
			},
			wantErr: false,
		},
		{
			name: "wrong kind",
			event: &nostr.Event{
				Kind:      1,
				CreatedAt: time.Now().Unix(),
				Tags:      nostr.Tags{},
			},
			opts:    ValidationOptions{},
			wantErr: true,
		},
		{
			name: "old timestamp",
			event: &nostr.Event{
				PubKey:    key.Public(),
				CreatedAt: time.Now().Unix() - 120, // 2 minutes ago
				Kind:      27235,
				Tags:      nostr.Tags{},
				Sig:       event.Sig,
			},
			opts: ValidationOptions{
				TimestampWindow: 60, // Only accept 60 seconds
			},
			wantErr: true,
		},
		{
			name:  "URI mismatch",
			event: event,
			opts: ValidationOptions{
				TimestampWindow: 60,
				ExpectedURI:     "https://different.com/test",
			},
			wantErr: true,
		},
		{
			name:  "method mismatch",
			event: event,
			opts: ValidationOptions{
				TimestampWindow: 60,
				ExpectedMethod:  "GET",
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := ValidateAuthEvent(tt.event, tt.opts)
			if (err != nil) != tt.wantErr {
				t.Errorf("ValidateAuthEvent() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestPubkeyFromContext checks that PubkeyFromContext reports ("", false) for
// a bare context and returns the stored value once one is attached under
// pubkeyContextKey.
func TestPubkeyFromContext(t *testing.T) {
	ctx := context.Background()

	// Test empty context
	pubkey, ok := PubkeyFromContext(ctx)
	if ok {
		t.Error("expected ok=false for empty context")
	}
	if pubkey != "" {
		t.Error("expected empty pubkey for empty context")
	}

	// Test context with pubkey
	expectedPubkey := "test-pubkey-123"
	ctx = context.WithValue(ctx, pubkeyContextKey, expectedPubkey)

	pubkey, ok = PubkeyFromContext(ctx)
	if !ok {
		t.Error("expected ok=true for context with pubkey")
	}
	if pubkey != expectedPubkey {
		t.Errorf("got pubkey %s, want %s", pubkey, expectedPubkey)
	}
}

// TestValidateAuthFromContext builds an incoming gRPC context whose
// "authorization" metadata holds a signed auth event and expects
// validateAuthFromContext to return the signer's public key.
func TestValidateAuthFromContext(t *testing.T) {
	key, err := nostr.GenerateKey()
	if err != nil {
		t.Fatalf("failed to generate key: %v", err)
	}

	// Create valid auth event
	event := &nostr.Event{
		PubKey:    key.Public(),
		CreatedAt: time.Now().Unix(),
		Kind:      27235,
		Tags: nostr.Tags{
			{"u", "https://example.com/test"},
			{"method", "POST"},
		},
		Content: "",
	}
	// NOTE(review): if Sign returns an error it should be checked here —
	// confirm against the nostr package.
	key.Sign(event)

	// A marshal failure would produce an unusable header; fail fast instead
	// of letting validation report a misleading error.
	eventJSON, err := json.Marshal(event)
	if err != nil {
		t.Fatalf("failed to marshal event: %v", err)
	}
	authHeader := "Nostr " + base64.StdEncoding.EncodeToString(eventJSON)

	// Create context with metadata
	md := metadata.Pairs("authorization", authHeader)
	ctx := metadata.NewIncomingContext(context.Background(), md)

	opts := &InterceptorOptions{
		TimestampWindow: 60,
		Required:        true,
	}

	pubkey, err := validateAuthFromContext(ctx, "/test.Service/Method", opts)
	if err != nil {
		t.Fatalf("validateAuthFromContext failed: %v", err)
	}

	if pubkey != key.Public() {
		t.Errorf("got pubkey %s, want %s", pubkey, key.Public())
	}
}

// TestShouldSkipAuth checks that shouldSkipAuth returns true only for methods
// present in the supplied skip list.
func TestShouldSkipAuth(t *testing.T) {
	skipMethods := []string{
		"/health/Check",
		"/nostr.v1.NostrRelay/GetInfo",
	}

	tests := []struct {
		method string
		want   bool
	}{
		{"/health/Check", true},
		{"/nostr.v1.NostrRelay/GetInfo", true},
		{"/nostr.v1.NostrRelay/PublishEvent", false},
		{"/other/Method", false},
	}

	for _, tt := range tests {
		t.Run(tt.method, func(t *testing.T) {
			got := shouldSkipAuth(tt.method, skipMethods)
			if got != tt.want {
				t.Errorf("shouldSkipAuth(%s) = %v, want %v", tt.method, got, tt.want)
			}
		})
	}
}

// TestHashPayload checks the length, determinism, and input sensitivity of
// the string returned by HashPayload.
func TestHashPayload(t *testing.T) {
	payload := []byte("test payload")
	hash := HashPayload(payload)

	// Should be a 64-character hex string (SHA256)
	if len(hash) != 64 {
		t.Errorf("hash length = %d, want 64", len(hash))
	}

	// Same payload should produce same hash
	hash2 := HashPayload(payload)
	if hash != hash2 {
		t.Error("same payload produced different hashes")
	}

	// Different payload should produce different hash
	hash3 := HashPayload([]byte("different payload"))
	if hash == hash3 {
		t.Error("different payloads produced same hash")
	}
}

// TestIsWriteMethod checks isWriteMethod's classification of full gRPC method
// names into write and read methods, including empty and degenerate names.
func TestIsWriteMethod(t *testing.T) {
	tests := []struct {
		method string
		want   bool
	}{
		// Write methods
		{"/nostr.v1.NostrRelay/PublishEvent", true},
		{"/nostr.v1.NostrRelay/DeleteEvent", true},
		{"/admin.v1.Admin/CreateUser", true},
		{"/admin.v1.Admin/UpdateSettings", true},
		{"/data.v1.Data/InsertRecord", true},
		{"/data.v1.Data/RemoveItem", true},
		{"/storage.v1.Storage/SetValue", true},
		{"/storage.v1.Storage/PutObject", true},

		// Read methods
		{"/nostr.v1.NostrRelay/QueryEvents", false},
		{"/nostr.v1.NostrRelay/Subscribe", false},
		{"/nostr.v1.NostrRelay/GetEvent", false},
		{"/admin.v1.Admin/ListUsers", false},
		{"/health.v1.Health/Check", false},
		{"/info.v1.Info/GetRelayInfo", false},

		// Edge cases
		{"", false},
		{"/", false},
	}

	for _, tt := range tests {
		t.Run(tt.method, func(t *testing.T) {
			got := isWriteMethod(tt.method)
			if got != tt.want {
				t.Errorf("isWriteMethod(%q) = %v, want %v", tt.method, got, tt.want)
			}
		})
	}
}