From 4dbd96ec697196d43ad41eca4fd43d53da46a081 Mon Sep 17 00:00:00 2001 From: bndw Date: Sat, 14 Feb 2026 14:54:30 -0800 Subject: refactor: use AuthStore interface and remove type assertions Replace runtime type assertions with compile-time safe AuthStore interface. Add connState struct for cleaner per-connection state management instead of mutable pointer parameters. Reduce auth challenge TTL from 10min to 2min. - Add AuthStore interface with CreateAuthChallenge and ValidateAndConsumeChallenge - Add connState struct for authenticatedPubkey and authChallenge - Remove fragile type assertion pattern in requireAuth and handleAuth - Add nil checks for auth store with clear error messages - Update Handler to have separate auth field - Wire auth store in main.go when auth is enabled --- internal/handler/websocket/handler.go | 85 +++++++++++++++++++++-------------- internal/storage/auth.go | 2 +- 2 files changed, 52 insertions(+), 35 deletions(-) (limited to 'internal') diff --git a/internal/handler/websocket/handler.go b/internal/handler/websocket/handler.go index 581c434..8bd246d 100644 --- a/internal/handler/websocket/handler.go +++ b/internal/handler/websocket/handler.go @@ -20,6 +20,11 @@ type EventStore interface { ProcessDeletion(context.Context, *pb.Event) error } +type AuthStore interface { + CreateAuthChallenge(context.Context) (string, error) + ValidateAndConsumeChallenge(context.Context, string) error +} + type MetricsRecorder interface { IncrementSubscriptions() DecrementSubscriptions() @@ -27,18 +32,24 @@ type MetricsRecorder interface { } type AuthConfig struct { - ReadEnabled bool - WriteEnabled bool - ReadAllowedPubkeys []string + ReadEnabled bool + WriteEnabled bool + ReadAllowedPubkeys []string WriteAllowedPubkeys []string } +type connState struct { + authenticatedPubkey string + authChallenge string +} + type Handler struct { - store EventStore - subs *subscription.Manager - metrics MetricsRecorder + store EventStore + auth AuthStore + subs *subscription.Manager + metrics MetricsRecorder authConfig *AuthConfig - indexData IndexData + indexData IndexData } func NewHandler(store EventStore, subs *subscription.Manager) *Handler { @@ -52,6 +63,10 @@ func (h *Handler) SetMetrics(m MetricsRecorder) { h.metrics = m } +func (h *Handler) SetAuth(a AuthStore) { + h.auth = a +} + func (h *Handler) SetAuthConfig(cfg *AuthConfig) { h.authConfig = cfg } @@ -91,8 +106,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() clientSubs := make(map[string]*subscription.Subscription) - var authenticatedPubkey string - var authChallenge string + state := &connState{} defer func() { count := len(clientSubs) @@ -112,14 +126,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if err := h.handleMessage(ctx, conn, data, clientSubs, &authenticatedPubkey, &authChallenge); err != nil { + if err := h.handleMessage(ctx, conn, data, clientSubs, state); err != nil { log.Printf("Message handling error: %v", err) h.sendNotice(ctx, conn, err.Error()) } } } -func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { +func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription, state *connState) error { var raw []json.RawMessage if err := json.Unmarshal(data, &raw); err != nil { return fmt.Errorf("invalid JSON") @@ -136,19 +150,19 @@ func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data switch msgType { case "EVENT": - return h.handleEvent(ctx, conn, raw, authenticatedPubkey, authChallenge) + return h.handleEvent(ctx, conn, raw, state) case "REQ": - return h.handleReq(ctx, conn, raw, clientSubs, authenticatedPubkey, authChallenge) + return h.handleReq(ctx, conn, raw, clientSubs, state) case "CLOSE": return h.handleClose(raw, clientSubs) case "AUTH": - return h.handleAuth(ctx, conn, raw, authenticatedPubkey, authChallenge) + return h.handleAuth(ctx, conn, raw, state) default: return fmt.Errorf("unknown message type: %s", msgType) } } -func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite bool, authenticatedPubkey *string, authChallenge *string) error { +func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite bool, state *connState) error { authRequired := false var allowedPubkeys []string @@ -166,15 +180,16 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite return nil } - if *authenticatedPubkey == "" { - if *authChallenge == "" { - challenge, err := h.store.(interface { - CreateAuthChallenge(context.Context) (string, error) - }).CreateAuthChallenge(ctx) + if state.authenticatedPubkey == "" { + if state.authChallenge == "" { + if h.auth == nil { + return fmt.Errorf("auth required but no auth store configured") + } + challenge, err := h.auth.CreateAuthChallenge(ctx) if err != nil { return fmt.Errorf("failed to create auth challenge: %w", err) } - *authChallenge = challenge + state.authChallenge = challenge h.sendAuthChallenge(ctx, conn, challenge) } return nil @@ -183,7 +198,7 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite if len(allowedPubkeys) > 0 { allowed := false for _, pk := range allowedPubkeys { - if pk == *authenticatedPubkey { + if pk == state.authenticatedPubkey { allowed = true break } @@ -196,16 +211,16 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite return nil } -func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { +func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, state *connState) error { if len(raw) != 2 { return fmt.Errorf("EVENT expects 2 elements") } - if err := h.requireAuth(ctx, conn, true, authenticatedPubkey, authChallenge); err != nil { + if err := h.requireAuth(ctx, conn, true, state); err != nil { return err } - if *authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.WriteEnabled { + if state.authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.WriteEnabled { return nil } @@ -258,16 +273,16 @@ func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []j return nil } -func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { +func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription, state *connState) error { if len(raw) < 3 { return fmt.Errorf("REQ expects at least 3 elements") } - if err := h.requireAuth(ctx, conn, false, authenticatedPubkey, authChallenge); err != nil { + if err := h.requireAuth(ctx, conn, false, state); err != nil { return err } - if *authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.ReadEnabled { + if state.authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.ReadEnabled { return nil } @@ -396,7 +411,7 @@ func (h *Handler) sendAuthChallenge(ctx context.Context, conn *websocket.Conn, c return conn.Write(ctx, websocket.MessageText, data) } -func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { +func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, state *connState) error { if len(raw) != 2 { return fmt.Errorf("AUTH expects 2 elements") } @@ -420,17 +435,19 @@ func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []js } eventChallenge := challengeTag.Value() - if eventChallenge != *authChallenge { + if eventChallenge != state.authChallenge { return fmt.Errorf("challenge mismatch") } - if err := h.store.(interface { - ValidateAndConsumeChallenge(context.Context, string) error - }).ValidateAndConsumeChallenge(ctx, eventChallenge); err != nil { + if h.auth == nil { + return fmt.Errorf("auth required but no auth store configured") + } + + if err := h.auth.ValidateAndConsumeChallenge(ctx, eventChallenge); err != nil { return fmt.Errorf("invalid challenge: %w", err) } - *authenticatedPubkey = authEvent.PubKey + state.authenticatedPubkey = authEvent.PubKey log.Printf("WebSocket client authenticated: %s", authEvent.PubKey[:16]) h.sendOK(ctx, conn, authEvent.ID, true, "") diff --git a/internal/storage/auth.go b/internal/storage/auth.go index 6eefa41..e17ffeb 100644 --- a/internal/storage/auth.go +++ b/internal/storage/auth.go @@ -10,7 +10,7 @@ import ( const ( ChallengeLength = 32 // bytes - ChallengeTTL = 10 * time.Minute + ChallengeTTL = 2 * time.Minute ) func generateChallenge() (string, error) { -- cgit v1.2.3