From 32ca0fba5108d0dc2c7415f36e55f031d5a0562e Mon Sep 17 00:00:00 2001 From: bndw Date: Sat, 14 Feb 2026 21:53:14 -0800 Subject: feat: add rate limiting to WebSocket connections WebSocket clients were completely unprotected from abuse. Add RateLimiter interface to WebSocket handler and enforce limits on EVENT and REQ messages. - Add RateLimiter interface with Allow(identifier, method) method - Track client IP in connState (proxy-aware via X-Forwarded-For) - Check rate limits in handleEvent and handleReq - Use authenticated pubkey as identifier, fallback to IP - Share same rate limiter instance with gRPC - Add getClientIP() helper that checks proxy headers first Critical security fix for production deployment. Without this, any client could spam unlimited events/subscriptions via WebSocket. --- cmd/relay/main.go | 7 ++++- internal/handler/websocket/handler.go | 54 ++++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/cmd/relay/main.go b/cmd/relay/main.go index e4afec8..86a29cb 100644 --- a/cmd/relay/main.go +++ b/cmd/relay/main.go @@ -79,8 +79,9 @@ func main() { streamInterceptors = append(streamInterceptors, auth.NostrStreamInterceptor(authOpts)) } + var limiter *ratelimit.Limiter if cfg.RateLimit.Enabled { - limiter := ratelimit.New(cfg.RateLimit.ToRateLimiter()) + limiter = ratelimit.New(cfg.RateLimit.ToRateLimiter()) unaryInterceptors = append(unaryInterceptors, ratelimit.UnaryInterceptor(limiter)) streamInterceptors = append(streamInterceptors, ratelimit.StreamInterceptor(limiter)) } @@ -125,6 +126,10 @@ func main() { }) } + if limiter != nil { + wsHandler.SetRateLimiter(limiter) + } + var grpcDisplay, httpDisplay, wsDisplay string if cfg.Server.PublicURL != "" { grpcDisplay = cfg.Server.PublicURL + ":443" diff --git a/internal/handler/websocket/handler.go b/internal/handler/websocket/handler.go index 8bd246d..a7b73ec 100644 --- a/internal/handler/websocket/handler.go +++ b/internal/handler/websocket/handler.go @@ -31,6 +31,10 @@ type MetricsRecorder interface { SetActiveSubscriptions(count int) } +type RateLimiter interface { + Allow(identifier, method string) bool +} + type AuthConfig struct { ReadEnabled bool WriteEnabled bool @@ -41,6 +45,7 @@ type AuthConfig struct { type connState struct { authenticatedPubkey string authChallenge string + clientIP string } type Handler struct { @@ -48,6 +53,7 @@ type Handler struct { auth AuthStore subs *subscription.Manager metrics MetricsRecorder + limiter RateLimiter authConfig *AuthConfig indexData IndexData } @@ -67,6 +73,10 @@ func (h *Handler) SetAuth(a AuthStore) { h.auth = a } +func (h *Handler) SetRateLimiter(l RateLimiter) { + h.limiter = l +} + func (h *Handler) SetAuthConfig(cfg *AuthConfig) { h.authConfig = cfg } @@ -106,7 +116,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() clientSubs := make(map[string]*subscription.Subscription) - state := &connState{} + state := &connState{ + clientIP: getClientIP(r), + } defer func() { count := len(clientSubs) @@ -224,6 +236,17 @@ func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []j return nil } + // Rate limiting - use pubkey if authenticated, otherwise IP + if h.limiter != nil { + identifier := state.authenticatedPubkey + if identifier == "" { + identifier = state.clientIP + } + if !h.limiter.Allow(identifier, "EVENT") { + return fmt.Errorf("rate limit exceeded") + } + } + var event nostr.Event if err := json.Unmarshal(raw[1], &event); err != nil { return fmt.Errorf("invalid event: %w", err) @@ -286,6 +309,17 @@ func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []jso return nil } + // Rate limiting - use pubkey if authenticated, otherwise IP + if h.limiter != nil { + identifier := state.authenticatedPubkey + if identifier == "" { + identifier = state.clientIP + } + if !h.limiter.Allow(identifier, "REQ") { + return fmt.Errorf("rate limit exceeded") + } + } + var subID string if err := json.Unmarshal(raw[1], &subID); err != nil { return fmt.Errorf("invalid subscription ID") @@ -453,3 +487,21 @@ func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []js h.sendOK(ctx, conn, authEvent.ID, true, "") return nil } + +// getClientIP extracts the real client IP from the HTTP request. +// Checks X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups), +// then falls back to RemoteAddr. +func getClientIP(r *http.Request) string { + // Check X-Forwarded-For header (set by reverse proxies like Caddy/nginx) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + return xff + } + + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr (direct connection or proxy IP) + return r.RemoteAddr +} -- cgit v1.2.3