diff options
| author | bndw <ben@bdw.to> | 2026-02-14 21:53:14 -0800 |
|---|---|---|
| committer | bndw <ben@bdw.to> | 2026-02-14 21:53:14 -0800 |
| commit | 32ca0fba5108d0dc2c7415f36e55f031d5a0562e (patch) | |
| tree | ff91309ef9af7d0ec8c9b5bd5b6f39f073d4be55 /internal/handler/websocket | |
| parent | e647880669b79cd968231cf85dc037a18e8bfd9c (diff) | |
feat: add rate limiting to WebSocket connections
WebSocket clients were completely unprotected from abuse. Add RateLimiter
interface to WebSocket handler and enforce limits on EVENT and REQ messages.
- Add RateLimiter interface with Allow(identifier, method) method
- Track client IP in connState (proxy-aware via X-Forwarded-For)
- Check rate limits in handleEvent and handleReq
- Use authenticated pubkey as identifier, fallback to IP
- Share same rate limiter instance with gRPC
- Add getClientIP() helper that checks proxy headers first
Critical security fix for production deployment. Without this, any client
could spam unlimited events/subscriptions via WebSocket.
Diffstat (limited to 'internal/handler/websocket')
| -rw-r--r-- | internal/handler/websocket/handler.go | 54 |
1 files changed, 53 insertions, 1 deletions
diff --git a/internal/handler/websocket/handler.go b/internal/handler/websocket/handler.go index 8bd246d..a7b73ec 100644 --- a/internal/handler/websocket/handler.go +++ b/internal/handler/websocket/handler.go | |||
| @@ -31,6 +31,10 @@ type MetricsRecorder interface { | |||
| 31 | SetActiveSubscriptions(count int) | 31 | SetActiveSubscriptions(count int) |
| 32 | } | 32 | } |
| 33 | 33 | ||
| 34 | type RateLimiter interface { | ||
| 35 | Allow(identifier, method string) bool | ||
| 36 | } | ||
| 37 | |||
| 34 | type AuthConfig struct { | 38 | type AuthConfig struct { |
| 35 | ReadEnabled bool | 39 | ReadEnabled bool |
| 36 | WriteEnabled bool | 40 | WriteEnabled bool |
| @@ -41,6 +45,7 @@ type AuthConfig struct { | |||
| 41 | type connState struct { | 45 | type connState struct { |
| 42 | authenticatedPubkey string | 46 | authenticatedPubkey string |
| 43 | authChallenge string | 47 | authChallenge string |
| 48 | clientIP string | ||
| 44 | } | 49 | } |
| 45 | 50 | ||
| 46 | type Handler struct { | 51 | type Handler struct { |
| @@ -48,6 +53,7 @@ type Handler struct { | |||
| 48 | auth AuthStore | 53 | auth AuthStore |
| 49 | subs *subscription.Manager | 54 | subs *subscription.Manager |
| 50 | metrics MetricsRecorder | 55 | metrics MetricsRecorder |
| 56 | limiter RateLimiter | ||
| 51 | authConfig *AuthConfig | 57 | authConfig *AuthConfig |
| 52 | indexData IndexData | 58 | indexData IndexData |
| 53 | } | 59 | } |
| @@ -67,6 +73,10 @@ func (h *Handler) SetAuth(a AuthStore) { | |||
| 67 | h.auth = a | 73 | h.auth = a |
| 68 | } | 74 | } |
| 69 | 75 | ||
| 76 | func (h *Handler) SetRateLimiter(l RateLimiter) { | ||
| 77 | h.limiter = l | ||
| 78 | } | ||
| 79 | |||
| 70 | func (h *Handler) SetAuthConfig(cfg *AuthConfig) { | 80 | func (h *Handler) SetAuthConfig(cfg *AuthConfig) { |
| 71 | h.authConfig = cfg | 81 | h.authConfig = cfg |
| 72 | } | 82 | } |
| @@ -106,7 +116,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||
| 106 | 116 | ||
| 107 | ctx := r.Context() | 117 | ctx := r.Context() |
| 108 | clientSubs := make(map[string]*subscription.Subscription) | 118 | clientSubs := make(map[string]*subscription.Subscription) |
| 109 | state := &connState{} | 119 | state := &connState{ |
| 120 | clientIP: getClientIP(r), | ||
| 121 | } | ||
| 110 | 122 | ||
| 111 | defer func() { | 123 | defer func() { |
| 112 | count := len(clientSubs) | 124 | count := len(clientSubs) |
| @@ -224,6 +236,17 @@ func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []j | |||
| 224 | return nil | 236 | return nil |
| 225 | } | 237 | } |
| 226 | 238 | ||
| 239 | // Rate limiting - use pubkey if authenticated, otherwise IP | ||
| 240 | if h.limiter != nil { | ||
| 241 | identifier := state.authenticatedPubkey | ||
| 242 | if identifier == "" { | ||
| 243 | identifier = state.clientIP | ||
| 244 | } | ||
| 245 | if !h.limiter.Allow(identifier, "EVENT") { | ||
| 246 | return fmt.Errorf("rate limit exceeded") | ||
| 247 | } | ||
| 248 | } | ||
| 249 | |||
| 227 | var event nostr.Event | 250 | var event nostr.Event |
| 228 | if err := json.Unmarshal(raw[1], &event); err != nil { | 251 | if err := json.Unmarshal(raw[1], &event); err != nil { |
| 229 | return fmt.Errorf("invalid event: %w", err) | 252 | return fmt.Errorf("invalid event: %w", err) |
| @@ -286,6 +309,17 @@ func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []jso | |||
| 286 | return nil | 309 | return nil |
| 287 | } | 310 | } |
| 288 | 311 | ||
| 312 | // Rate limiting - use pubkey if authenticated, otherwise IP | ||
| 313 | if h.limiter != nil { | ||
| 314 | identifier := state.authenticatedPubkey | ||
| 315 | if identifier == "" { | ||
| 316 | identifier = state.clientIP | ||
| 317 | } | ||
| 318 | if !h.limiter.Allow(identifier, "REQ") { | ||
| 319 | return fmt.Errorf("rate limit exceeded") | ||
| 320 | } | ||
| 321 | } | ||
| 322 | |||
| 289 | var subID string | 323 | var subID string |
| 290 | if err := json.Unmarshal(raw[1], &subID); err != nil { | 324 | if err := json.Unmarshal(raw[1], &subID); err != nil { |
| 291 | return fmt.Errorf("invalid subscription ID") | 325 | return fmt.Errorf("invalid subscription ID") |
| @@ -453,3 +487,21 @@ func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []js | |||
| 453 | h.sendOK(ctx, conn, authEvent.ID, true, "") | 487 | h.sendOK(ctx, conn, authEvent.ID, true, "") |
| 454 | return nil | 488 | return nil |
| 455 | } | 489 | } |
| 490 | |||
| 491 | // getClientIP extracts the real client IP from the HTTP request. | ||
| 492 | // Checks X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups), | ||
| 493 | // then falls back to RemoteAddr. | ||
| 494 | func getClientIP(r *http.Request) string { | ||
| 495 | // Check X-Forwarded-For header (set by reverse proxies like Caddy/nginx) | ||
| 496 | if xff := r.Header.Get("X-Forwarded-For"); xff != "" { | ||
| 497 | return xff | ||
| 498 | } | ||
| 499 | |||
| 500 | // Check X-Real-IP header | ||
| 501 | if xri := r.Header.Get("X-Real-IP"); xri != "" { | ||
| 502 | return xri | ||
| 503 | } | ||
| 504 | |||
| 505 | // Fall back to RemoteAddr (direct connection or proxy IP) | ||
| 506 | return r.RemoteAddr | ||
| 507 | } | ||
