diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/handler/websocket/handler.go | 85 | ||||
| -rw-r--r-- | internal/storage/auth.go | 2 |
2 files changed, 52 insertions, 35 deletions
diff --git a/internal/handler/websocket/handler.go b/internal/handler/websocket/handler.go index 581c434..8bd246d 100644 --- a/internal/handler/websocket/handler.go +++ b/internal/handler/websocket/handler.go | |||
| @@ -20,6 +20,11 @@ type EventStore interface { | |||
| 20 | ProcessDeletion(context.Context, *pb.Event) error | 20 | ProcessDeletion(context.Context, *pb.Event) error |
| 21 | } | 21 | } |
| 22 | 22 | ||
| 23 | type AuthStore interface { | ||
| 24 | CreateAuthChallenge(context.Context) (string, error) | ||
| 25 | ValidateAndConsumeChallenge(context.Context, string) error | ||
| 26 | } | ||
| 27 | |||
| 23 | type MetricsRecorder interface { | 28 | type MetricsRecorder interface { |
| 24 | IncrementSubscriptions() | 29 | IncrementSubscriptions() |
| 25 | DecrementSubscriptions() | 30 | DecrementSubscriptions() |
| @@ -27,18 +32,24 @@ type MetricsRecorder interface { | |||
| 27 | } | 32 | } |
| 28 | 33 | ||
| 29 | type AuthConfig struct { | 34 | type AuthConfig struct { |
| 30 | ReadEnabled bool | 35 | ReadEnabled bool |
| 31 | WriteEnabled bool | 36 | WriteEnabled bool |
| 32 | ReadAllowedPubkeys []string | 37 | ReadAllowedPubkeys []string |
| 33 | WriteAllowedPubkeys []string | 38 | WriteAllowedPubkeys []string |
| 34 | } | 39 | } |
| 35 | 40 | ||
| 41 | type connState struct { | ||
| 42 | authenticatedPubkey string | ||
| 43 | authChallenge string | ||
| 44 | } | ||
| 45 | |||
| 36 | type Handler struct { | 46 | type Handler struct { |
| 37 | store EventStore | 47 | store EventStore |
| 38 | subs *subscription.Manager | 48 | auth AuthStore |
| 39 | metrics MetricsRecorder | 49 | subs *subscription.Manager |
| 50 | metrics MetricsRecorder | ||
| 40 | authConfig *AuthConfig | 51 | authConfig *AuthConfig |
| 41 | indexData IndexData | 52 | indexData IndexData |
| 42 | } | 53 | } |
| 43 | 54 | ||
| 44 | func NewHandler(store EventStore, subs *subscription.Manager) *Handler { | 55 | func NewHandler(store EventStore, subs *subscription.Manager) *Handler { |
| @@ -52,6 +63,10 @@ func (h *Handler) SetMetrics(m MetricsRecorder) { | |||
| 52 | h.metrics = m | 63 | h.metrics = m |
| 53 | } | 64 | } |
| 54 | 65 | ||
| 66 | func (h *Handler) SetAuth(a AuthStore) { | ||
| 67 | h.auth = a | ||
| 68 | } | ||
| 69 | |||
| 55 | func (h *Handler) SetAuthConfig(cfg *AuthConfig) { | 70 | func (h *Handler) SetAuthConfig(cfg *AuthConfig) { |
| 56 | h.authConfig = cfg | 71 | h.authConfig = cfg |
| 57 | } | 72 | } |
| @@ -91,8 +106,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||
| 91 | 106 | ||
| 92 | ctx := r.Context() | 107 | ctx := r.Context() |
| 93 | clientSubs := make(map[string]*subscription.Subscription) | 108 | clientSubs := make(map[string]*subscription.Subscription) |
| 94 | var authenticatedPubkey string | 109 | state := &connState{} |
| 95 | var authChallenge string | ||
| 96 | 110 | ||
| 97 | defer func() { | 111 | defer func() { |
| 98 | count := len(clientSubs) | 112 | count := len(clientSubs) |
| @@ -112,14 +126,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||
| 112 | return | 126 | return |
| 113 | } | 127 | } |
| 114 | 128 | ||
| 115 | if err := h.handleMessage(ctx, conn, data, clientSubs, &authenticatedPubkey, &authChallenge); err != nil { | 129 | if err := h.handleMessage(ctx, conn, data, clientSubs, state); err != nil { |
| 116 | log.Printf("Message handling error: %v", err) | 130 | log.Printf("Message handling error: %v", err) |
| 117 | h.sendNotice(ctx, conn, err.Error()) | 131 | h.sendNotice(ctx, conn, err.Error()) |
| 118 | } | 132 | } |
| 119 | } | 133 | } |
| 120 | } | 134 | } |
| 121 | 135 | ||
| 122 | func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { | 136 | func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription, state *connState) error { |
| 123 | var raw []json.RawMessage | 137 | var raw []json.RawMessage |
| 124 | if err := json.Unmarshal(data, &raw); err != nil { | 138 | if err := json.Unmarshal(data, &raw); err != nil { |
| 125 | return fmt.Errorf("invalid JSON") | 139 | return fmt.Errorf("invalid JSON") |
| @@ -136,19 +150,19 @@ func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data | |||
| 136 | 150 | ||
| 137 | switch msgType { | 151 | switch msgType { |
| 138 | case "EVENT": | 152 | case "EVENT": |
| 139 | return h.handleEvent(ctx, conn, raw, authenticatedPubkey, authChallenge) | 153 | return h.handleEvent(ctx, conn, raw, state) |
| 140 | case "REQ": | 154 | case "REQ": |
| 141 | return h.handleReq(ctx, conn, raw, clientSubs, authenticatedPubkey, authChallenge) | 155 | return h.handleReq(ctx, conn, raw, clientSubs, state) |
| 142 | case "CLOSE": | 156 | case "CLOSE": |
| 143 | return h.handleClose(raw, clientSubs) | 157 | return h.handleClose(raw, clientSubs) |
| 144 | case "AUTH": | 158 | case "AUTH": |
| 145 | return h.handleAuth(ctx, conn, raw, authenticatedPubkey, authChallenge) | 159 | return h.handleAuth(ctx, conn, raw, state) |
| 146 | default: | 160 | default: |
| 147 | return fmt.Errorf("unknown message type: %s", msgType) | 161 | return fmt.Errorf("unknown message type: %s", msgType) |
| 148 | } | 162 | } |
| 149 | } | 163 | } |
| 150 | 164 | ||
| 151 | func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite bool, authenticatedPubkey *string, authChallenge *string) error { | 165 | func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite bool, state *connState) error { |
| 152 | authRequired := false | 166 | authRequired := false |
| 153 | var allowedPubkeys []string | 167 | var allowedPubkeys []string |
| 154 | 168 | ||
| @@ -166,15 +180,16 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite | |||
| 166 | return nil | 180 | return nil |
| 167 | } | 181 | } |
| 168 | 182 | ||
| 169 | if *authenticatedPubkey == "" { | 183 | if state.authenticatedPubkey == "" { |
| 170 | if *authChallenge == "" { | 184 | if state.authChallenge == "" { |
| 171 | challenge, err := h.store.(interface { | 185 | if h.auth == nil { |
| 172 | CreateAuthChallenge(context.Context) (string, error) | 186 | return fmt.Errorf("auth required but no auth store configured") |
| 173 | }).CreateAuthChallenge(ctx) | 187 | } |
| 188 | challenge, err := h.auth.CreateAuthChallenge(ctx) | ||
| 174 | if err != nil { | 189 | if err != nil { |
| 175 | return fmt.Errorf("failed to create auth challenge: %w", err) | 190 | return fmt.Errorf("failed to create auth challenge: %w", err) |
| 176 | } | 191 | } |
| 177 | *authChallenge = challenge | 192 | state.authChallenge = challenge |
| 178 | h.sendAuthChallenge(ctx, conn, challenge) | 193 | h.sendAuthChallenge(ctx, conn, challenge) |
| 179 | } | 194 | } |
| 180 | return nil | 195 | return nil |
| @@ -183,7 +198,7 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite | |||
| 183 | if len(allowedPubkeys) > 0 { | 198 | if len(allowedPubkeys) > 0 { |
| 184 | allowed := false | 199 | allowed := false |
| 185 | for _, pk := range allowedPubkeys { | 200 | for _, pk := range allowedPubkeys { |
| 186 | if pk == *authenticatedPubkey { | 201 | if pk == state.authenticatedPubkey { |
| 187 | allowed = true | 202 | allowed = true |
| 188 | break | 203 | break |
| 189 | } | 204 | } |
| @@ -196,16 +211,16 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite | |||
| 196 | return nil | 211 | return nil |
| 197 | } | 212 | } |
| 198 | 213 | ||
| 199 | func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { | 214 | func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, state *connState) error { |
| 200 | if len(raw) != 2 { | 215 | if len(raw) != 2 { |
| 201 | return fmt.Errorf("EVENT expects 2 elements") | 216 | return fmt.Errorf("EVENT expects 2 elements") |
| 202 | } | 217 | } |
| 203 | 218 | ||
| 204 | if err := h.requireAuth(ctx, conn, true, authenticatedPubkey, authChallenge); err != nil { | 219 | if err := h.requireAuth(ctx, conn, true, state); err != nil { |
| 205 | return err | 220 | return err |
| 206 | } | 221 | } |
| 207 | 222 | ||
| 208 | if *authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.WriteEnabled { | 223 | if state.authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.WriteEnabled { |
| 209 | return nil | 224 | return nil |
| 210 | } | 225 | } |
| 211 | 226 | ||
| @@ -258,16 +273,16 @@ func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []j | |||
| 258 | return nil | 273 | return nil |
| 259 | } | 274 | } |
| 260 | 275 | ||
| 261 | func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { | 276 | func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription, state *connState) error { |
| 262 | if len(raw) < 3 { | 277 | if len(raw) < 3 { |
| 263 | return fmt.Errorf("REQ expects at least 3 elements") | 278 | return fmt.Errorf("REQ expects at least 3 elements") |
| 264 | } | 279 | } |
| 265 | 280 | ||
| 266 | if err := h.requireAuth(ctx, conn, false, authenticatedPubkey, authChallenge); err != nil { | 281 | if err := h.requireAuth(ctx, conn, false, state); err != nil { |
| 267 | return err | 282 | return err |
| 268 | } | 283 | } |
| 269 | 284 | ||
| 270 | if *authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.ReadEnabled { | 285 | if state.authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.ReadEnabled { |
| 271 | return nil | 286 | return nil |
| 272 | } | 287 | } |
| 273 | 288 | ||
| @@ -396,7 +411,7 @@ func (h *Handler) sendAuthChallenge(ctx context.Context, conn *websocket.Conn, c | |||
| 396 | return conn.Write(ctx, websocket.MessageText, data) | 411 | return conn.Write(ctx, websocket.MessageText, data) |
| 397 | } | 412 | } |
| 398 | 413 | ||
| 399 | func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { | 414 | func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, state *connState) error { |
| 400 | if len(raw) != 2 { | 415 | if len(raw) != 2 { |
| 401 | return fmt.Errorf("AUTH expects 2 elements") | 416 | return fmt.Errorf("AUTH expects 2 elements") |
| 402 | } | 417 | } |
| @@ -420,17 +435,19 @@ func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []js | |||
| 420 | } | 435 | } |
| 421 | 436 | ||
| 422 | eventChallenge := challengeTag.Value() | 437 | eventChallenge := challengeTag.Value() |
| 423 | if eventChallenge != *authChallenge { | 438 | if eventChallenge != state.authChallenge { |
| 424 | return fmt.Errorf("challenge mismatch") | 439 | return fmt.Errorf("challenge mismatch") |
| 425 | } | 440 | } |
| 426 | 441 | ||
| 427 | if err := h.store.(interface { | 442 | if h.auth == nil { |
| 428 | ValidateAndConsumeChallenge(context.Context, string) error | 443 | return fmt.Errorf("auth required but no auth store configured") |
| 429 | }).ValidateAndConsumeChallenge(ctx, eventChallenge); err != nil { | 444 | } |
| 445 | |||
| 446 | if err := h.auth.ValidateAndConsumeChallenge(ctx, eventChallenge); err != nil { | ||
| 430 | return fmt.Errorf("invalid challenge: %w", err) | 447 | return fmt.Errorf("invalid challenge: %w", err) |
| 431 | } | 448 | } |
| 432 | 449 | ||
| 433 | *authenticatedPubkey = authEvent.PubKey | 450 | state.authenticatedPubkey = authEvent.PubKey |
| 434 | log.Printf("WebSocket client authenticated: %s", authEvent.PubKey[:16]) | 451 | log.Printf("WebSocket client authenticated: %s", authEvent.PubKey[:16]) |
| 435 | 452 | ||
| 436 | h.sendOK(ctx, conn, authEvent.ID, true, "") | 453 | h.sendOK(ctx, conn, authEvent.ID, true, "") |
diff --git a/internal/storage/auth.go b/internal/storage/auth.go index 6eefa41..e17ffeb 100644 --- a/internal/storage/auth.go +++ b/internal/storage/auth.go | |||
| @@ -10,7 +10,7 @@ import ( | |||
| 10 | 10 | ||
| 11 | const ( | 11 | const ( |
| 12 | ChallengeLength = 32 // bytes | 12 | ChallengeLength = 32 // bytes |
| 13 | ChallengeTTL = 10 * time.Minute | 13 | ChallengeTTL = 2 * time.Minute |
| 14 | ) | 14 | ) |
| 15 | 15 | ||
| 16 | func generateChallenge() (string, error) { | 16 | func generateChallenge() (string, error) { |
