summaryrefslogtreecommitdiffstats
path: root/internal/handler/websocket
diff options
context:
space:
mode:
Diffstat (limited to 'internal/handler/websocket')
-rw-r--r--internal/handler/websocket/handler.go85
1 files changed, 51 insertions, 34 deletions
diff --git a/internal/handler/websocket/handler.go b/internal/handler/websocket/handler.go
index 581c434..8bd246d 100644
--- a/internal/handler/websocket/handler.go
+++ b/internal/handler/websocket/handler.go
@@ -20,6 +20,11 @@ type EventStore interface {
20 ProcessDeletion(context.Context, *pb.Event) error 20 ProcessDeletion(context.Context, *pb.Event) error
21} 21}
22 22
23type AuthStore interface {
24 CreateAuthChallenge(context.Context) (string, error)
25 ValidateAndConsumeChallenge(context.Context, string) error
26}
27
23type MetricsRecorder interface { 28type MetricsRecorder interface {
24 IncrementSubscriptions() 29 IncrementSubscriptions()
25 DecrementSubscriptions() 30 DecrementSubscriptions()
@@ -27,18 +32,24 @@ type MetricsRecorder interface {
27} 32}
28 33
29type AuthConfig struct { 34type AuthConfig struct {
30 ReadEnabled bool 35 ReadEnabled bool
31 WriteEnabled bool 36 WriteEnabled bool
32 ReadAllowedPubkeys []string 37 ReadAllowedPubkeys []string
33 WriteAllowedPubkeys []string 38 WriteAllowedPubkeys []string
34} 39}
35 40
41type connState struct {
42 authenticatedPubkey string
43 authChallenge string
44}
45
36type Handler struct { 46type Handler struct {
37 store EventStore 47 store EventStore
38 subs *subscription.Manager 48 auth AuthStore
39 metrics MetricsRecorder 49 subs *subscription.Manager
50 metrics MetricsRecorder
40 authConfig *AuthConfig 51 authConfig *AuthConfig
41 indexData IndexData 52 indexData IndexData
42} 53}
43 54
44func NewHandler(store EventStore, subs *subscription.Manager) *Handler { 55func NewHandler(store EventStore, subs *subscription.Manager) *Handler {
@@ -52,6 +63,10 @@ func (h *Handler) SetMetrics(m MetricsRecorder) {
52 h.metrics = m 63 h.metrics = m
53} 64}
54 65
66func (h *Handler) SetAuth(a AuthStore) {
67 h.auth = a
68}
69
55func (h *Handler) SetAuthConfig(cfg *AuthConfig) { 70func (h *Handler) SetAuthConfig(cfg *AuthConfig) {
56 h.authConfig = cfg 71 h.authConfig = cfg
57} 72}
@@ -91,8 +106,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
91 106
92 ctx := r.Context() 107 ctx := r.Context()
93 clientSubs := make(map[string]*subscription.Subscription) 108 clientSubs := make(map[string]*subscription.Subscription)
94 var authenticatedPubkey string 109 state := &connState{}
95 var authChallenge string
96 110
97 defer func() { 111 defer func() {
98 count := len(clientSubs) 112 count := len(clientSubs)
@@ -112,14 +126,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
112 return 126 return
113 } 127 }
114 128
115 if err := h.handleMessage(ctx, conn, data, clientSubs, &authenticatedPubkey, &authChallenge); err != nil { 129 if err := h.handleMessage(ctx, conn, data, clientSubs, state); err != nil {
116 log.Printf("Message handling error: %v", err) 130 log.Printf("Message handling error: %v", err)
117 h.sendNotice(ctx, conn, err.Error()) 131 h.sendNotice(ctx, conn, err.Error())
118 } 132 }
119 } 133 }
120} 134}
121 135
122func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { 136func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription, state *connState) error {
123 var raw []json.RawMessage 137 var raw []json.RawMessage
124 if err := json.Unmarshal(data, &raw); err != nil { 138 if err := json.Unmarshal(data, &raw); err != nil {
125 return fmt.Errorf("invalid JSON") 139 return fmt.Errorf("invalid JSON")
@@ -136,19 +150,19 @@ func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data
136 150
137 switch msgType { 151 switch msgType {
138 case "EVENT": 152 case "EVENT":
139 return h.handleEvent(ctx, conn, raw, authenticatedPubkey, authChallenge) 153 return h.handleEvent(ctx, conn, raw, state)
140 case "REQ": 154 case "REQ":
141 return h.handleReq(ctx, conn, raw, clientSubs, authenticatedPubkey, authChallenge) 155 return h.handleReq(ctx, conn, raw, clientSubs, state)
142 case "CLOSE": 156 case "CLOSE":
143 return h.handleClose(raw, clientSubs) 157 return h.handleClose(raw, clientSubs)
144 case "AUTH": 158 case "AUTH":
145 return h.handleAuth(ctx, conn, raw, authenticatedPubkey, authChallenge) 159 return h.handleAuth(ctx, conn, raw, state)
146 default: 160 default:
147 return fmt.Errorf("unknown message type: %s", msgType) 161 return fmt.Errorf("unknown message type: %s", msgType)
148 } 162 }
149} 163}
150 164
151func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite bool, authenticatedPubkey *string, authChallenge *string) error { 165func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite bool, state *connState) error {
152 authRequired := false 166 authRequired := false
153 var allowedPubkeys []string 167 var allowedPubkeys []string
154 168
@@ -166,15 +180,16 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite
166 return nil 180 return nil
167 } 181 }
168 182
169 if *authenticatedPubkey == "" { 183 if state.authenticatedPubkey == "" {
170 if *authChallenge == "" { 184 if state.authChallenge == "" {
171 challenge, err := h.store.(interface { 185 if h.auth == nil {
172 CreateAuthChallenge(context.Context) (string, error) 186 return fmt.Errorf("auth required but no auth store configured")
173 }).CreateAuthChallenge(ctx) 187 }
188 challenge, err := h.auth.CreateAuthChallenge(ctx)
174 if err != nil { 189 if err != nil {
175 return fmt.Errorf("failed to create auth challenge: %w", err) 190 return fmt.Errorf("failed to create auth challenge: %w", err)
176 } 191 }
177 *authChallenge = challenge 192 state.authChallenge = challenge
178 h.sendAuthChallenge(ctx, conn, challenge) 193 h.sendAuthChallenge(ctx, conn, challenge)
179 } 194 }
180 return nil 195 return nil
@@ -183,7 +198,7 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite
183 if len(allowedPubkeys) > 0 { 198 if len(allowedPubkeys) > 0 {
184 allowed := false 199 allowed := false
185 for _, pk := range allowedPubkeys { 200 for _, pk := range allowedPubkeys {
186 if pk == *authenticatedPubkey { 201 if pk == state.authenticatedPubkey {
187 allowed = true 202 allowed = true
188 break 203 break
189 } 204 }
@@ -196,16 +211,16 @@ func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite
196 return nil 211 return nil
197} 212}
198 213
199func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { 214func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, state *connState) error {
200 if len(raw) != 2 { 215 if len(raw) != 2 {
201 return fmt.Errorf("EVENT expects 2 elements") 216 return fmt.Errorf("EVENT expects 2 elements")
202 } 217 }
203 218
204 if err := h.requireAuth(ctx, conn, true, authenticatedPubkey, authChallenge); err != nil { 219 if err := h.requireAuth(ctx, conn, true, state); err != nil {
205 return err 220 return err
206 } 221 }
207 222
208 if *authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.WriteEnabled { 223 if state.authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.WriteEnabled {
209 return nil 224 return nil
210 } 225 }
211 226
@@ -258,16 +273,16 @@ func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []j
258 return nil 273 return nil
259} 274}
260 275
261func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { 276func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription, state *connState) error {
262 if len(raw) < 3 { 277 if len(raw) < 3 {
263 return fmt.Errorf("REQ expects at least 3 elements") 278 return fmt.Errorf("REQ expects at least 3 elements")
264 } 279 }
265 280
266 if err := h.requireAuth(ctx, conn, false, authenticatedPubkey, authChallenge); err != nil { 281 if err := h.requireAuth(ctx, conn, false, state); err != nil {
267 return err 282 return err
268 } 283 }
269 284
270 if *authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.ReadEnabled { 285 if state.authenticatedPubkey == "" && h.authConfig != nil && h.authConfig.ReadEnabled {
271 return nil 286 return nil
272 } 287 }
273 288
@@ -396,7 +411,7 @@ func (h *Handler) sendAuthChallenge(ctx context.Context, conn *websocket.Conn, c
396 return conn.Write(ctx, websocket.MessageText, data) 411 return conn.Write(ctx, websocket.MessageText, data)
397} 412}
398 413
399func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { 414func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, state *connState) error {
400 if len(raw) != 2 { 415 if len(raw) != 2 {
401 return fmt.Errorf("AUTH expects 2 elements") 416 return fmt.Errorf("AUTH expects 2 elements")
402 } 417 }
@@ -420,17 +435,19 @@ func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []js
420 } 435 }
421 436
422 eventChallenge := challengeTag.Value() 437 eventChallenge := challengeTag.Value()
423 if eventChallenge != *authChallenge { 438 if eventChallenge != state.authChallenge {
424 return fmt.Errorf("challenge mismatch") 439 return fmt.Errorf("challenge mismatch")
425 } 440 }
426 441
427 if err := h.store.(interface { 442 if h.auth == nil {
428 ValidateAndConsumeChallenge(context.Context, string) error 443 return fmt.Errorf("auth required but no auth store configured")
429 }).ValidateAndConsumeChallenge(ctx, eventChallenge); err != nil { 444 }
445
446 if err := h.auth.ValidateAndConsumeChallenge(ctx, eventChallenge); err != nil {
430 return fmt.Errorf("invalid challenge: %w", err) 447 return fmt.Errorf("invalid challenge: %w", err)
431 } 448 }
432 449
433 *authenticatedPubkey = authEvent.PubKey 450 state.authenticatedPubkey = authEvent.PubKey
434 log.Printf("WebSocket client authenticated: %s", authEvent.PubKey[:16]) 451 log.Printf("WebSocket client authenticated: %s", authEvent.PubKey[:16])
435 452
436 h.sendOK(ctx, conn, authEvent.ID, true, "") 453 h.sendOK(ctx, conn, authEvent.ID, true, "")