diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/handler/websocket/handler.go | 132 | ||||
| -rw-r--r-- | internal/storage/auth.go | 88 |
2 files changed, 214 insertions, 6 deletions
diff --git a/internal/handler/websocket/handler.go b/internal/handler/websocket/handler.go index b7ea71d..c8fb6cc 100644 --- a/internal/handler/websocket/handler.go +++ b/internal/handler/websocket/handler.go | |||
| @@ -26,10 +26,18 @@ type MetricsRecorder interface { | |||
| 26 | SetActiveSubscriptions(count int) | 26 | SetActiveSubscriptions(count int) |
| 27 | } | 27 | } |
| 28 | 28 | ||
| 29 | type AuthConfig struct { | ||
| 30 | ReadEnabled bool | ||
| 31 | WriteEnabled bool | ||
| 32 | ReadAllowedPubkeys []string | ||
| 33 | WriteAllowedPubkeys []string | ||
| 34 | } | ||
| 35 | |||
| 29 | type Handler struct { | 36 | type Handler struct { |
| 30 | store EventStore | 37 | store EventStore |
| 31 | subs *subscription.Manager | 38 | subs *subscription.Manager |
| 32 | metrics MetricsRecorder | 39 | metrics MetricsRecorder |
| 40 | authConfig *AuthConfig | ||
| 33 | indexData IndexData | 41 | indexData IndexData |
| 34 | } | 42 | } |
| 35 | 43 | ||
| @@ -44,6 +52,10 @@ func (h *Handler) SetMetrics(m MetricsRecorder) { | |||
| 44 | h.metrics = m | 52 | h.metrics = m |
| 45 | } | 53 | } |
| 46 | 54 | ||
| 55 | func (h *Handler) SetAuthConfig(cfg *AuthConfig) { | ||
| 56 | h.authConfig = cfg | ||
| 57 | } | ||
| 58 | |||
| 47 | // SetIndexData sets the addresses for the index page | 59 | // SetIndexData sets the addresses for the index page |
| 48 | func (h *Handler) SetIndexData(grpcAddr, httpAddr, wsAddr string) { | 60 | func (h *Handler) SetIndexData(grpcAddr, httpAddr, wsAddr string) { |
| 49 | h.indexData = IndexData{ | 61 | h.indexData = IndexData{ |
| @@ -79,6 +91,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||
| 79 | 91 | ||
| 80 | ctx := r.Context() | 92 | ctx := r.Context() |
| 81 | clientSubs := make(map[string]*subscription.Subscription) | 93 | clientSubs := make(map[string]*subscription.Subscription) |
| 94 | var authenticatedPubkey string | ||
| 95 | var authChallenge string | ||
| 96 | |||
| 82 | defer func() { | 97 | defer func() { |
| 83 | count := len(clientSubs) | 98 | count := len(clientSubs) |
| 84 | for subID := range clientSubs { | 99 | for subID := range clientSubs { |
| @@ -97,14 +112,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||
| 97 | return | 112 | return |
| 98 | } | 113 | } |
| 99 | 114 | ||
| 100 | if err := h.handleMessage(ctx, conn, data, clientSubs); err != nil { | 115 | if err := h.handleMessage(ctx, conn, data, clientSubs, &authenticatedPubkey, &authChallenge); err != nil { |
| 101 | log.Printf("Message handling error: %v", err) | 116 | log.Printf("Message handling error: %v", err) |
| 102 | h.sendNotice(ctx, conn, err.Error()) | 117 | h.sendNotice(ctx, conn, err.Error()) |
| 103 | } | 118 | } |
| 104 | } | 119 | } |
| 105 | } | 120 | } |
| 106 | 121 | ||
| 107 | func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription) error { | 122 | func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data []byte, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { |
| 108 | var raw []json.RawMessage | 123 | var raw []json.RawMessage |
| 109 | if err := json.Unmarshal(data, &raw); err != nil { | 124 | if err := json.Unmarshal(data, &raw); err != nil { |
| 110 | return fmt.Errorf("invalid JSON") | 125 | return fmt.Errorf("invalid JSON") |
| @@ -121,21 +136,75 @@ func (h *Handler) handleMessage(ctx context.Context, conn *websocket.Conn, data | |||
| 121 | 136 | ||
| 122 | switch msgType { | 137 | switch msgType { |
| 123 | case "EVENT": | 138 | case "EVENT": |
| 124 | return h.handleEvent(ctx, conn, raw) | 139 | return h.handleEvent(ctx, conn, raw, authenticatedPubkey, authChallenge) |
| 125 | case "REQ": | 140 | case "REQ": |
| 126 | return h.handleReq(ctx, conn, raw, clientSubs) | 141 | return h.handleReq(ctx, conn, raw, clientSubs, authenticatedPubkey, authChallenge) |
| 127 | case "CLOSE": | 142 | case "CLOSE": |
| 128 | return h.handleClose(raw, clientSubs) | 143 | return h.handleClose(raw, clientSubs) |
| 144 | case "AUTH": | ||
| 145 | return h.handleAuth(ctx, conn, raw, authenticatedPubkey, authChallenge) | ||
| 129 | default: | 146 | default: |
| 130 | return fmt.Errorf("unknown message type: %s", msgType) | 147 | return fmt.Errorf("unknown message type: %s", msgType) |
| 131 | } | 148 | } |
| 132 | } | 149 | } |
| 133 | 150 | ||
| 134 | func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage) error { | 151 | func (h *Handler) requireAuth(ctx context.Context, conn *websocket.Conn, isWrite bool, authenticatedPubkey *string, authChallenge *string) error { |
| 152 | authRequired := false | ||
| 153 | var allowedPubkeys []string | ||
| 154 | |||
| 155 | if h.authConfig != nil { | ||
| 156 | if isWrite && h.authConfig.WriteEnabled { | ||
| 157 | authRequired = true | ||
| 158 | allowedPubkeys = h.authConfig.WriteAllowedPubkeys | ||
| 159 | } else if !isWrite && h.authConfig.ReadEnabled { | ||
| 160 | authRequired = true | ||
| 161 | allowedPubkeys = h.authConfig.ReadAllowedPubkeys | ||
| 162 | } | ||
| 163 | } | ||
| 164 | |||
| 165 | if !authRequired { | ||
| 166 | return nil | ||
| 167 | } | ||
| 168 | |||
| 169 | if *authenticatedPubkey == "" { | ||
| 170 | if *authChallenge == "" { | ||
| 171 | challenge, err := h.store.(interface { | ||
| 172 | CreateAuthChallenge(context.Context) (string, error) | ||
| 173 | }).CreateAuthChallenge(ctx) | ||
| 174 | if err != nil { | ||
| 175 | return fmt.Errorf("failed to create auth challenge: %w", err) | ||
| 176 | } | ||
| 177 | *authChallenge = challenge | ||
| 178 | h.sendAuthChallenge(ctx, conn, challenge) | ||
| 179 | } | ||
| 180 | return fmt.Errorf("restricted: authentication required") | ||
| 181 | } | ||
| 182 | |||
| 183 | if len(allowedPubkeys) > 0 { | ||
| 184 | allowed := false | ||
| 185 | for _, pk := range allowedPubkeys { | ||
| 186 | if pk == *authenticatedPubkey { | ||
| 187 | allowed = true | ||
| 188 | break | ||
| 189 | } | ||
| 190 | } | ||
| 191 | if !allowed { | ||
| 192 | return fmt.Errorf("restricted: pubkey not authorized") | ||
| 193 | } | ||
| 194 | } | ||
| 195 | |||
| 196 | return nil | ||
| 197 | } | ||
| 198 | |||
| 199 | func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { | ||
| 135 | if len(raw) != 2 { | 200 | if len(raw) != 2 { |
| 136 | return fmt.Errorf("EVENT expects 2 elements") | 201 | return fmt.Errorf("EVENT expects 2 elements") |
| 137 | } | 202 | } |
| 138 | 203 | ||
| 204 | if err := h.requireAuth(ctx, conn, true, authenticatedPubkey, authChallenge); err != nil { | ||
| 205 | return err | ||
| 206 | } | ||
| 207 | |||
| 139 | var event nostr.Event | 208 | var event nostr.Event |
| 140 | if err := json.Unmarshal(raw[1], &event); err != nil { | 209 | if err := json.Unmarshal(raw[1], &event); err != nil { |
| 141 | return fmt.Errorf("invalid event: %w", err) | 210 | return fmt.Errorf("invalid event: %w", err) |
| @@ -185,11 +254,15 @@ func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []j | |||
| 185 | return nil | 254 | return nil |
| 186 | } | 255 | } |
| 187 | 256 | ||
| 188 | func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription) error { | 257 | func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, clientSubs map[string]*subscription.Subscription, authenticatedPubkey *string, authChallenge *string) error { |
| 189 | if len(raw) < 3 { | 258 | if len(raw) < 3 { |
| 190 | return fmt.Errorf("REQ expects at least 3 elements") | 259 | return fmt.Errorf("REQ expects at least 3 elements") |
| 191 | } | 260 | } |
| 192 | 261 | ||
| 262 | if err := h.requireAuth(ctx, conn, false, authenticatedPubkey, authChallenge); err != nil { | ||
| 263 | return err | ||
| 264 | } | ||
| 265 | |||
| 193 | var subID string | 266 | var subID string |
| 194 | if err := json.Unmarshal(raw[1], &subID); err != nil { | 267 | if err := json.Unmarshal(raw[1], &subID); err != nil { |
| 195 | return fmt.Errorf("invalid subscription ID") | 268 | return fmt.Errorf("invalid subscription ID") |
| @@ -308,3 +381,50 @@ func (h *Handler) sendNotice(ctx context.Context, conn *websocket.Conn, notice s | |||
| 308 | data, _ := json.Marshal(msg) | 381 | data, _ := json.Marshal(msg) |
| 309 | return conn.Write(ctx, websocket.MessageText, data) | 382 | return conn.Write(ctx, websocket.MessageText, data) |
| 310 | } | 383 | } |
| 384 | |||
| 385 | func (h *Handler) sendAuthChallenge(ctx context.Context, conn *websocket.Conn, challenge string) error { | ||
| 386 | msg := []interface{}{"AUTH", challenge} | ||
| 387 | data, _ := json.Marshal(msg) | ||
| 388 | return conn.Write(ctx, websocket.MessageText, data) | ||
| 389 | } | ||
| 390 | |||
| 391 | func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []json.RawMessage, authenticatedPubkey *string, authChallenge *string) error { | ||
| 392 | if len(raw) != 2 { | ||
| 393 | return fmt.Errorf("AUTH expects 2 elements") | ||
| 394 | } | ||
| 395 | |||
| 396 | var authEvent nostr.Event | ||
| 397 | if err := json.Unmarshal(raw[1], &authEvent); err != nil { | ||
| 398 | return fmt.Errorf("invalid auth event: %w", err) | ||
| 399 | } | ||
| 400 | |||
| 401 | if authEvent.Kind != 22242 { | ||
| 402 | return fmt.Errorf("invalid auth event kind: expected 22242, got %d", authEvent.Kind) | ||
| 403 | } | ||
| 404 | |||
| 405 | if !authEvent.Verify() { | ||
| 406 | return fmt.Errorf("invalid auth event signature") | ||
| 407 | } | ||
| 408 | |||
| 409 | challengeTag := authEvent.Tags.Find("challenge") | ||
| 410 | if challengeTag == nil { | ||
| 411 | return fmt.Errorf("missing challenge tag in auth event") | ||
| 412 | } | ||
| 413 | |||
| 414 | eventChallenge := challengeTag.Value() | ||
| 415 | if eventChallenge != *authChallenge { | ||
| 416 | return fmt.Errorf("challenge mismatch") | ||
| 417 | } | ||
| 418 | |||
| 419 | if err := h.store.(interface { | ||
| 420 | ValidateAndConsumeChallenge(context.Context, string) error | ||
| 421 | }).ValidateAndConsumeChallenge(ctx, eventChallenge); err != nil { | ||
| 422 | return fmt.Errorf("invalid challenge: %w", err) | ||
| 423 | } | ||
| 424 | |||
| 425 | *authenticatedPubkey = authEvent.PubKey | ||
| 426 | log.Printf("WebSocket client authenticated: %s", authEvent.PubKey[:16]) | ||
| 427 | |||
| 428 | h.sendOK(ctx, conn, authEvent.ID, true, "") | ||
| 429 | return nil | ||
| 430 | } | ||
diff --git a/internal/storage/auth.go b/internal/storage/auth.go new file mode 100644 index 0000000..6eefa41 --- /dev/null +++ b/internal/storage/auth.go | |||
| @@ -0,0 +1,88 @@ | |||
| 1 | package storage | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "context" | ||
| 5 | "crypto/rand" | ||
| 6 | "encoding/hex" | ||
| 7 | "fmt" | ||
| 8 | "time" | ||
| 9 | ) | ||
| 10 | |||
| 11 | const ( | ||
| 12 | ChallengeLength = 32 // bytes | ||
| 13 | ChallengeTTL = 10 * time.Minute | ||
| 14 | ) | ||
| 15 | |||
| 16 | func generateChallenge() (string, error) { | ||
| 17 | bytes := make([]byte, ChallengeLength) | ||
| 18 | if _, err := rand.Read(bytes); err != nil { | ||
| 19 | return "", fmt.Errorf("failed to generate challenge: %w", err) | ||
| 20 | } | ||
| 21 | return hex.EncodeToString(bytes), nil | ||
| 22 | } | ||
| 23 | |||
| 24 | func (s *Storage) CreateAuthChallenge(ctx context.Context) (string, error) { | ||
| 25 | challenge, err := generateChallenge() | ||
| 26 | if err != nil { | ||
| 27 | return "", err | ||
| 28 | } | ||
| 29 | |||
| 30 | now := time.Now().Unix() | ||
| 31 | expiresAt := time.Now().Add(ChallengeTTL).Unix() | ||
| 32 | |||
| 33 | _, err = s.db.ExecContext(ctx, | ||
| 34 | "INSERT INTO auth_challenges (challenge, created_at, expires_at, used) VALUES (?, ?, ?, 0)", | ||
| 35 | challenge, now, expiresAt, | ||
| 36 | ) | ||
| 37 | if err != nil { | ||
| 38 | return "", fmt.Errorf("failed to store challenge: %w", err) | ||
| 39 | } | ||
| 40 | |||
| 41 | return challenge, nil | ||
| 42 | } | ||
| 43 | |||
| 44 | func (s *Storage) ValidateAndConsumeChallenge(ctx context.Context, challenge string) error { | ||
| 45 | tx, err := s.db.BeginTx(ctx, nil) | ||
| 46 | if err != nil { | ||
| 47 | return fmt.Errorf("failed to begin transaction: %w", err) | ||
| 48 | } | ||
| 49 | defer tx.Rollback() | ||
| 50 | |||
| 51 | var expiresAt int64 | ||
| 52 | var used int | ||
| 53 | err = tx.QueryRowContext(ctx, | ||
| 54 | "SELECT expires_at, used FROM auth_challenges WHERE challenge = ?", | ||
| 55 | challenge, | ||
| 56 | ).Scan(&expiresAt, &used) | ||
| 57 | |||
| 58 | if err != nil { | ||
| 59 | return fmt.Errorf("challenge not found or invalid") | ||
| 60 | } | ||
| 61 | |||
| 62 | if used != 0 { | ||
| 63 | return fmt.Errorf("challenge already used") | ||
| 64 | } | ||
| 65 | |||
| 66 | if time.Now().Unix() > expiresAt { | ||
| 67 | return fmt.Errorf("challenge expired") | ||
| 68 | } | ||
| 69 | |||
| 70 | _, err = tx.ExecContext(ctx, | ||
| 71 | "UPDATE auth_challenges SET used = 1 WHERE challenge = ?", | ||
| 72 | challenge, | ||
| 73 | ) | ||
| 74 | if err != nil { | ||
| 75 | return fmt.Errorf("failed to mark challenge as used: %w", err) | ||
| 76 | } | ||
| 77 | |||
| 78 | return tx.Commit() | ||
| 79 | } | ||
| 80 | |||
| 81 | func (s *Storage) CleanupExpiredChallenges(ctx context.Context) error { | ||
| 82 | now := time.Now().Unix() | ||
| 83 | _, err := s.db.ExecContext(ctx, | ||
| 84 | "DELETE FROM auth_challenges WHERE expires_at < ?", | ||
| 85 | now, | ||
| 86 | ) | ||
| 87 | return err | ||
| 88 | } | ||
