package websocket import ( "context" "encoding/json" "fmt" "log" "net/http" pb "northwest.io/muxstr/api/nostr/v1" "northwest.io/muxstr/internal/storage" "northwest.io/muxstr/internal/subscription" "northwest.io/muxstr/internal/websocket" "northwest.io/nostr" ) type EventStore interface { StoreEvent(context.Context, *storage.EventData) error QueryEvents(context.Context, []*pb.Filter, *storage.QueryOptions) ([]*pb.Event, error) ProcessDeletion(context.Context, *pb.Event) error } type MetricsRecorder interface { IncrementSubscriptions() DecrementSubscriptions() SetActiveSubscriptions(count int) } type AuthConfig struct { ReadEnabled bool WriteEnabled bool ReadAllowedPubkeys []string WriteAllowedPubkeys []string } type Handler struct { store EventStore subs *subscription.Manager metrics MetricsRecorder authConfig *AuthConfig indexData IndexData } func NewHandler(store EventStore, subs *subscription.Manager) *Handler { return &Handler{ store: store, subs: subs, } } func (h *Handler) SetMetrics(m MetricsRecorder) { h.metrics = m } func (h *Handler) SetAuthConfig(cfg *AuthConfig) { h.authConfig = cfg } // SetIndexData sets the addresses for the index page func (h *Handler) SetIndexData(grpcAddr, httpAddr, wsAddr string) { h.indexData = IndexData{ GrpcAddr: grpcAddr, HttpAddr: httpAddr, WsAddr: wsAddr, } } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Handle GET requests (but not WebSocket upgrades) if r.Method == "GET" && r.Header.Get("Upgrade") != "websocket" { accept := r.Header.Get("Accept") // NIP-11: Relay information document if accept == "application/nostr+json" { h.ServeNIP11(w, r) return } // Serve HTML index for browsers h.ServeIndex(w, r, h.indexData) return } // Handle WebSocket upgrade conn, err := websocket.Accept(w, r) if err != nil { log.Printf("WebSocket accept failed: %v", err) return } defer conn.Close(websocket.StatusNormalClosure, "") ctx := r.Context() clientSubs := make(map[string]*subscription.Subscription) var authenticatedPubkey string var authChallenge string defer func() { count := len(clientSubs) for subID := range clientSubs { h.subs.Remove(subID) } if h.metrics != nil && count > 0 { for i := 0; i < count; i++ { h.metrics.DecrementSubscriptions() } } }() for { _, data, err := conn.Read(ctx) if err != nil { return } if err := h.handleMessage(ctx, conn, data, clientSubs, &authenticatedPubkey, &authChallenge); err != nil { log.Printf("Message handling error: %v", err) h.sendNotice(ctx, conn, err.Error()) } } } func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { var raw []json.RawMessage if err := json.Unmarshal(data, &raw); err != nil { return fmt.Errorf("invalid JSON") } if len(raw) == 0 { return fmt.Errorf("empty message") } var msgType string if err := json.Unmarshal(raw[0], &msgType); err != nil { return fmt.Errorf("invalid message type") } switch msgType { case "EVENT": return h.handleEvent(ctx, conn, raw, authenticatedPubkey, authChallenge) case "REQ": return h.handleReq(ctx, conn, raw, clientSubs, authenticatedPubkey, authChallenge) case "CLOSE": return h.handleClose(raw, clientSubs) case "AUTH": return h.handleAuth(ctx, conn, raw, authenticatedPubkey, authChallenge) default: return fmt.Errorf("unknown message type: %s", msgType) } } func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite bool, authenticatedPubkey *string, authChallenge *string) error { authRequired := false var allowedPubkeys []string if h.authConfig != nil { if isWrite && h.authConfig.WriteEnabled { authRequired = true allowedPubkeys = h.authConfig.WriteAllowedPubkeys } else if !isWrite && h.authConfig.ReadEnabled { authRequired = true allowedPubkeys = h.authConfig.ReadAllowedPubkeys } } if !authRequired { return nil } if *authenticatedPubkey == "" { if *authChallenge == "" { challenge, err := h.store.(interface { CreateAuthChallenge(context.Context) (string, error) }).CreateAuthChallenge(ctx) if err != nil { return fmt.Errorf("failed to create auth challenge: %w", err) } *authChallenge = challenge h.sendAuthChallenge(ctx, conn, challenge) } return fmt.Errorf("restricted: authentication required") } if len(allowedPubkeys) > 0 { allowed := false for _, pk := range allowedPubkeys { if pk == *authenticatedPubkey { allowed = true break } } if !allowed { return fmt.Errorf("restricted: pubkey not authorized") } } return nil } func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { if len(raw) != 2 { return fmt.Errorf("EVENT expects 2 elements") } if err := h.requireAuth(ctx, conn, true, authenticatedPubkey, authChallenge); err != nil { return err } var event nostr.Event if err := json.Unmarshal(raw[1], &event); err != nil { return fmt.Errorf("invalid event: %w", err) } if !event.CheckID() { h.sendOK(ctx, conn, event.ID, false, "invalid: event ID mismatch") return nil } if !event.Verify() { h.sendOK(ctx, conn, event.ID, false, "invalid: signature verification failed") return nil } pbEvent := NostrToPB(&event) canonicalJSON := event.Serialize() // Handle deletion events (kind 5) - process but don't store if pbEvent.Kind == 5 { if err := h.store.ProcessDeletion(ctx, pbEvent); err != nil { h.sendOK(ctx, conn, event.ID, false, fmt.Sprintf("deletion failed: %v", err)) return nil } h.sendOK(ctx, conn, event.ID, true, "deleted") return nil } eventData := &storage.EventData{ Event: pbEvent, CanonicalJSON: canonicalJSON, } err := h.store.StoreEvent(ctx, eventData) if err == storage.ErrEventExists { h.sendOK(ctx, conn, event.ID, true, "duplicate: already have this event") return nil } if err != nil { h.sendOK(ctx, conn, event.ID, false, fmt.Sprintf("error: %v", err)) return nil } h.subs.MatchAndFan(pbEvent) h.sendOK(ctx, conn, event.ID, true, "") return nil } func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { if len(raw) < 3 { return fmt.Errorf("REQ expects at least 3 elements") } if err := h.requireAuth(ctx, conn, false, authenticatedPubkey, authChallenge); err != nil { return err } var subID string if err := json.Unmarshal(raw[1], &subID); err != nil { return fmt.Errorf("invalid subscription ID") } var filters []*pb.Filter for i := 2; i < len(raw); i++ { var nostrFilter nostr.Filter if err := json.Unmarshal(raw[i], &nostrFilter); err != nil { return fmt.Errorf("invalid filter: %w", err) } pbFilter := NostrFilterToPB(&nostrFilter) filters = append(filters, pbFilter) } if existing, ok := clientSubs[subID]; ok { h.subs.Remove(existing.ID) delete(clientSubs, subID) if h.metrics != nil { h.metrics.DecrementSubscriptions() } } storedEvents, err := h.store.QueryEvents(ctx, filters, &storage.QueryOptions{Limit: 0}) if err != nil { return fmt.Errorf("query failed: %w", err) } for _, pbEvent := range storedEvents { event := PBToNostr(pbEvent) h.sendEvent(ctx, conn, subID, event) } h.sendEOSE(ctx, conn, subID) sub := &subscription.Subscription{ ID: subID, Filters: filters, Events: make(chan *pb.Event, 100), } sub.InitDone() h.subs.Add(sub) clientSubs[subID] = sub if h.metrics != nil { h.metrics.IncrementSubscriptions() } go h.streamEvents(ctx, conn, sub) return nil } func (h *Handler) handleClose(raw []json.RawMessage, clientSubs map[string]*subscription.Subscription) error { if len(raw) != 2 { return fmt.Errorf("CLOSE expects 2 elements") } var subID string if err := json.Unmarshal(raw[1], &subID); err != nil { return fmt.Errorf("invalid subscription ID") } if sub, ok := clientSubs[subID]; ok { h.subs.Remove(sub.ID) delete(clientSubs, subID) if h.metrics != nil { h.metrics.DecrementSubscriptions() } } return nil } func (h *Handler) streamEvents(ctx context.Context, conn *websocket.Conn, sub *subscription.Subscription) { for { select { case pbEvent, ok := <-sub.Events: if !ok { return } event := PBToNostr(pbEvent) h.sendEvent(ctx, conn, sub.ID, event) case <-ctx.Done(): return case <-sub.Done(): return } } } func (h *Handler) sendEvent(ctx context.Context, conn *websocket.Conn, subID string, event *nostr.Event) error { msg := []interface{}{"EVENT", subID, event} data, _ := json.Marshal(msg) return conn.Write(ctx, websocket.MessageText, data) } func (h *Handler) sendOK(ctx context.Context, conn *websocket.Conn, eventID string, accepted bool, message string) error { msg := []interface{}{"OK", eventID, accepted, message} data, _ := json.Marshal(msg) return conn.Write(ctx, websocket.MessageText, data) } func (h *Handler) sendEOSE(ctx context.Context, conn *websocket.Conn, subID string) error { msg := []interface{}{"EOSE", subID} data, _ := json.Marshal(msg) return conn.Write(ctx, websocket.MessageText, data) } func (h *Handler) sendNotice(ctx context.Context, conn *websocket.Conn, notice string) error { msg := []interface{}{"NOTICE", notice} data, _ := json.Marshal(msg) return conn.Write(ctx, websocket.MessageText, data) } func (h *Handler) sendAuthChallenge(ctx context.Context, conn *websocket.Conn, challenge string) error { msg := []interface{}{"AUTH", challenge} data, _ := json.Marshal(msg) return conn.Write(ctx, websocket.MessageText, data) } func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { if len(raw) != 2 { return fmt.Errorf("AUTH expects 2 elements") } var authEvent nostr.Event if err := json.Unmarshal(raw[1], &authEvent); err != nil { return fmt.Errorf("invalid auth event: %w", err) } if authEvent.Kind != 22242 { return fmt.Errorf("invalid auth event kind: expected 22242, got %d", authEvent.Kind) } if !authEvent.Verify() { return fmt.Errorf("invalid auth event signature") } challengeTag := authEvent.Tags.Find("challenge") if challengeTag == nil { return fmt.Errorf("missing challenge tag in auth event") } eventChallenge := challengeTag.Value() if eventChallenge != *authChallenge { return fmt.Errorf("challenge mismatch") } if err := h.store.(interface { ValidateAndConsumeChallenge(context.Context, string) error }).ValidateAndConsumeChallenge(ctx, eventChallenge); err != nil { return fmt.Errorf("invalid challenge: %w", err) } *authenticatedPubkey = authEvent.PubKey log.Printf("WebSocket client authenticated: %s", authEvent.PubKey[:16]) h.sendOK(ctx, conn, authEvent.ID, true, "") return nil }