summaryrefslogtreecommitdiffstats
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/handler/websocket/handler.go54
1 files changed, 53 insertions, 1 deletions
diff --git a/internal/handler/websocket/handler.go b/internal/handler/websocket/handler.go
index 8bd246d..a7b73ec 100644
--- a/internal/handler/websocket/handler.go
+++ b/internal/handler/websocket/handler.go
@@ -31,6 +31,10 @@ type MetricsRecorder interface {
31 SetActiveSubscriptions(count int) 31 SetActiveSubscriptions(count int)
32} 32}
33 33
34type RateLimiter interface {
35 Allow(identifier, method string) bool
36}
37
34type AuthConfig struct { 38type AuthConfig struct {
35 ReadEnabled bool 39 ReadEnabled bool
36 WriteEnabled bool 40 WriteEnabled bool
@@ -41,6 +45,7 @@ type AuthConfig struct {
41type connState struct { 45type connState struct {
42 authenticatedPubkey string 46 authenticatedPubkey string
43 authChallenge string 47 authChallenge string
48 clientIP string
44} 49}
45 50
46type Handler struct { 51type Handler struct {
@@ -48,6 +53,7 @@ type Handler struct {
48 auth AuthStore 53 auth AuthStore
49 subs *subscription.Manager 54 subs *subscription.Manager
50 metrics MetricsRecorder 55 metrics MetricsRecorder
56 limiter RateLimiter
51 authConfig *AuthConfig 57 authConfig *AuthConfig
52 indexData IndexData 58 indexData IndexData
53} 59}
@@ -67,6 +73,10 @@ func (h *Handler) SetAuth(a AuthStore) {
67 h.auth = a 73 h.auth = a
68} 74}
69 75
76func (h *Handler) SetRateLimiter(l RateLimiter) {
77 h.limiter = l
78}
79
70func (h *Handler) SetAuthConfig(cfg *AuthConfig) { 80func (h *Handler) SetAuthConfig(cfg *AuthConfig) {
71 h.authConfig = cfg 81 h.authConfig = cfg
72} 82}
@@ -106,7 +116,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
106 116
107 ctx := r.Context() 117 ctx := r.Context()
108 clientSubs := make(map[string]*subscription.Subscription) 118 clientSubs := make(map[string]*subscription.Subscription)
109 state := &connState{} 119 state := &connState{
120 clientIP: getClientIP(r),
121 }
110 122
111 defer func() { 123 defer func() {
112 count := len(clientSubs) 124 count := len(clientSubs)
@@ -224,6 +236,17 @@ func (h *Handler) handleEvent(ctx context.Context, conn *websocket.Conn, raw []j
224 return nil 236 return nil
225 } 237 }
226 238
239 // Rate limiting - use pubkey if authenticated, otherwise IP
240 if h.limiter != nil {
241 identifier := state.authenticatedPubkey
242 if identifier == "" {
243 identifier = state.clientIP
244 }
245 if !h.limiter.Allow(identifier, "EVENT") {
246 return fmt.Errorf("rate limit exceeded")
247 }
248 }
249
227 var event nostr.Event 250 var event nostr.Event
228 if err := json.Unmarshal(raw[1], &event); err != nil { 251 if err := json.Unmarshal(raw[1], &event); err != nil {
229 return fmt.Errorf("invalid event: %w", err) 252 return fmt.Errorf("invalid event: %w", err)
@@ -286,6 +309,17 @@ func (h *Handler) handleReq(ctx context.Context, conn *websocket.Conn, raw []jso
286 return nil 309 return nil
287 } 310 }
288 311
312 // Rate limiting - use pubkey if authenticated, otherwise IP
313 if h.limiter != nil {
314 identifier := state.authenticatedPubkey
315 if identifier == "" {
316 identifier = state.clientIP
317 }
318 if !h.limiter.Allow(identifier, "REQ") {
319 return fmt.Errorf("rate limit exceeded")
320 }
321 }
322
289 var subID string 323 var subID string
290 if err := json.Unmarshal(raw[1], &subID); err != nil { 324 if err := json.Unmarshal(raw[1], &subID); err != nil {
291 return fmt.Errorf("invalid subscription ID") 325 return fmt.Errorf("invalid subscription ID")
@@ -453,3 +487,21 @@ func (h *Handler) handleAuth(ctx context.Context, conn *websocket.Conn, raw []js
453 h.sendOK(ctx, conn, authEvent.ID, true, "") 487 h.sendOK(ctx, conn, authEvent.ID, true, "")
454 return nil 488 return nil
455} 489}
490
491// getClientIP extracts the real client IP from the HTTP request.
492// Checks X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups),
493// then falls back to RemoteAddr.
494func getClientIP(r *http.Request) string {
495 // Check X-Forwarded-For header (set by reverse proxies like Caddy/nginx)
496 if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
497 return xff
498 }
499
500 // Check X-Real-IP header
501 if xri := r.Header.Get("X-Real-IP"); xri != "" {
502 return xri
503 }
504
505 // Fall back to RemoteAddr (direct connection or proxy IP)
506 return r.RemoteAddr
507}