summaryrefslogtreecommitdiffstats
path: root/internal/ratelimit/interceptor.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ratelimit/interceptor.go')
-rw-r--r--internal/ratelimit/interceptor.go150
1 files changed, 150 insertions, 0 deletions
diff --git a/internal/ratelimit/interceptor.go b/internal/ratelimit/interceptor.go
new file mode 100644
index 0000000..b27fe7e
--- /dev/null
+++ b/internal/ratelimit/interceptor.go
@@ -0,0 +1,150 @@
1package ratelimit
2
3import (
4 "context"
5 "fmt"
6
7 "google.golang.org/grpc"
8 "google.golang.org/grpc/codes"
9 "google.golang.org/grpc/metadata"
10 "google.golang.org/grpc/peer"
11 "google.golang.org/grpc/status"
12
13 "northwest.io/muxstr/internal/auth"
14)
15
16// UnaryInterceptor creates a gRPC unary interceptor for rate limiting.
17// It should be chained after the auth interceptor so pubkey is available.
18func UnaryInterceptor(limiter *Limiter) grpc.UnaryServerInterceptor {
19 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
20 // Get identifier (pubkey or IP)
21 identifier := getIdentifier(ctx)
22
23 // Check rate limit
24 if !limiter.Allow(identifier, info.FullMethod) {
25 return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded for %s", identifier)
26 }
27
28 return handler(ctx, req)
29 }
30}
31
32// StreamInterceptor creates a gRPC stream interceptor for rate limiting.
33// It should be chained after the auth interceptor so pubkey is available.
34func StreamInterceptor(limiter *Limiter) grpc.StreamServerInterceptor {
35 return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
36 // Get identifier (pubkey or IP)
37 identifier := getIdentifier(ss.Context())
38
39 // Check rate limit
40 if !limiter.Allow(identifier, info.FullMethod) {
41 return status.Errorf(codes.ResourceExhausted, "rate limit exceeded for %s", identifier)
42 }
43
44 return handler(srv, ss)
45 }
46}
47
48// getIdentifier extracts the user identifier from the context.
49// Returns pubkey if authenticated, otherwise returns IP address.
50func getIdentifier(ctx context.Context) string {
51 // Try to get authenticated pubkey first
52 pubkey, ok := auth.PubkeyFromContext(ctx)
53 if ok && pubkey != "" {
54 return pubkey
55 }
56
57 // Fall back to IP address
58 return getIPAddress(ctx)
59}
60
61// getIPAddress extracts the client IP address from the context.
62func getIPAddress(ctx context.Context) string {
63 // Try to get from peer info
64 p, ok := peer.FromContext(ctx)
65 if ok && p.Addr != nil {
66 return p.Addr.String()
67 }
68
69 // Try to get from metadata (X-Forwarded-For header)
70 md, ok := metadata.FromIncomingContext(ctx)
71 if ok {
72 if xff := md.Get("x-forwarded-for"); len(xff) > 0 {
73 return xff[0]
74 }
75 if xri := md.Get("x-real-ip"); len(xri) > 0 {
76 return xri[0]
77 }
78 }
79
80 return "unknown"
81}
82
83// WaitUnaryInterceptor is a variant that waits instead of rejecting when rate limited.
84// Use this for critical operations that should never fail due to rate limiting.
85// WARNING: This can cause requests to hang if rate limit is never freed.
86func WaitUnaryInterceptor(limiter *Limiter) grpc.UnaryServerInterceptor {
87 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
88 identifier := getIdentifier(ctx)
89
90 // Get user limiters
91 userLims := limiter.getUserLimiters(identifier)
92 rateLimiter := userLims.getLimiterForMethod(info.FullMethod, limiter.config, identifier)
93
94 // Wait for permission (respects context deadline)
95 if err := rateLimiter.Wait(ctx); err != nil {
96 return nil, status.Errorf(codes.DeadlineExceeded, "rate limit wait cancelled: %v", err)
97 }
98
99 limiter.incrementAllowed()
100 return handler(ctx, req)
101 }
102}
103
104// RetryableError wraps a rate limit error with retry-after information.
105type RetryableError struct {
106 *status.Status
107 RetryAfter float64 // seconds
108}
109
110// Error implements the error interface.
111func (e *RetryableError) Error() string {
112 return fmt.Sprintf("%s (retry after %.1fs)", e.Status.Message(), e.RetryAfter)
113}
114
115// UnaryInterceptorWithRetryAfter is like UnaryInterceptor but includes retry-after info.
116// Clients can extract this to implement smart backoff.
117func UnaryInterceptorWithRetryAfter(limiter *Limiter) grpc.UnaryServerInterceptor {
118 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
119 identifier := getIdentifier(ctx)
120
121 // Get user limiters
122 userLims := limiter.getUserLimiters(identifier)
123 rateLimiter := userLims.getLimiterForMethod(info.FullMethod, limiter.config, identifier)
124
125 // Get reservation to check how long to wait
126 reservation := rateLimiter.Reserve()
127 if !reservation.OK() {
128 return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded (burst exhausted)")
129 }
130
131 delay := reservation.Delay()
132 if delay > 0 {
133 // Cancel the reservation since we're not going to wait
134 reservation.Cancel()
135
136 limiter.incrementDenied()
137
138 // Return error with retry-after information
139 st := status.New(codes.ResourceExhausted, fmt.Sprintf("rate limit exceeded for %s", identifier))
140 return nil, &RetryableError{
141 Status: st,
142 RetryAfter: delay.Seconds(),
143 }
144 }
145
146 // No delay needed, proceed
147 limiter.incrementAllowed()
148 return handler(ctx, req)
149 }
150}