summaryrefslogtreecommitdiffstats
path: root/internal/ratelimit/interceptor.go
blob: b27fe7ed6c4ca88332bffb2176d94e0888eebf59 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package ratelimit

import (
	"context"
	"fmt"
	"net"
	"strings"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"

	"northwest.io/muxstr/internal/auth"
)

// UnaryInterceptor returns a gRPC unary server interceptor that enforces the
// limiter's policy before invoking the handler.
//
// Callers are keyed by getIdentifier (authenticated pubkey, else client
// address), so this interceptor must run after the auth interceptor for the
// pubkey to be visible. Rejected calls fail with codes.ResourceExhausted and
// never reach the handler.
func UnaryInterceptor(limiter *Limiter) grpc.UnaryServerInterceptor {
	intercept := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		id := getIdentifier(ctx)

		if limiter.Allow(id, info.FullMethod) {
			return handler(ctx, req)
		}
		return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded for %s", id)
	}
	return intercept
}

// StreamInterceptor returns a gRPC stream server interceptor that enforces the
// limiter's policy once, when the stream is opened.
//
// Callers are keyed by getIdentifier (authenticated pubkey, else client
// address), so this interceptor must run after the auth interceptor for the
// pubkey to be visible. Rejected streams fail with codes.ResourceExhausted and
// never reach the handler.
func StreamInterceptor(limiter *Limiter) grpc.StreamServerInterceptor {
	intercept := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		id := getIdentifier(ss.Context())

		if limiter.Allow(id, info.FullMethod) {
			return handler(srv, ss)
		}
		return status.Errorf(codes.ResourceExhausted, "rate limit exceeded for %s", id)
	}
	return intercept
}

// getIdentifier picks the key used to rate-limit the caller in ctx.
//
// A non-empty authenticated pubkey (as stored by the auth package) wins;
// otherwise the caller is identified by its network address via getIPAddress.
func getIdentifier(ctx context.Context) string {
	if pk, found := auth.PubkeyFromContext(ctx); found && pk != "" {
		return pk
	}
	return getIPAddress(ctx)
}

// getIPAddress returns a client address to use as the rate-limit key for
// callers that have no authenticated pubkey.
//
// The transport peer address is preferred. Its port is stripped so that every
// connection from the same host shares one bucket. Keying on "host:port" would
// let a client evade the limit simply by opening new connections, each of
// which gets a fresh ephemeral port. Addresses that are not in host:port form
// (for example a Unix socket path) are returned unchanged.
//
// Only when no peer address is available are the x-forwarded-for and
// x-real-ip metadata headers consulted. For x-forwarded-for only the first
// comma-separated entry (the originating client) is used.
// NOTE(review): both headers are client-controlled unless a trusted proxy
// overwrites them, so they can be spoofed to obtain fresh buckets.
//
// Returns "unknown" if no address can be determined.
func getIPAddress(ctx context.Context) string {
	if p, ok := peer.FromContext(ctx); ok && p.Addr != nil {
		addr := p.Addr.String()
		if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
			return host
		}
		return addr
	}

	if md, ok := metadata.FromIncomingContext(ctx); ok {
		if xff := md.Get("x-forwarded-for"); len(xff) > 0 {
			// "client, proxy1, proxy2" -> "client"
			if client := strings.TrimSpace(strings.SplitN(xff[0], ",", 2)[0]); client != "" {
				return client
			}
		}
		if xri := md.Get("x-real-ip"); len(xri) > 0 {
			return strings.TrimSpace(xri[0])
		}
	}

	return "unknown"
}

// WaitUnaryInterceptor is a variant of UnaryInterceptor that blocks until the
// caller's per-method limiter admits the request, instead of rejecting it
// immediately. Use this for critical operations that should be delayed rather
// than failed by rate limiting.
//
// WARNING: the wait is bounded only by ctx. A request without a deadline can
// block for as long as it takes the limiter to free a token.
//
// If the wait fails, the request is counted as denied and rejected with
// codes.Canceled when ctx was cancelled, or codes.DeadlineExceeded otherwise.
// Previously every failure was reported as DeadlineExceeded and was not
// counted as denied.
func WaitUnaryInterceptor(limiter *Limiter) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		identifier := getIdentifier(ctx)

		// Per-identifier, per-method limiter.
		userLims := limiter.getUserLimiters(identifier)
		rateLimiter := userLims.getLimiterForMethod(info.FullMethod, limiter.config, identifier)

		// Wait for permission; returns early with an error if ctx ends first.
		if err := rateLimiter.Wait(ctx); err != nil {
			limiter.incrementDenied()

			// Distinguish a client cancellation from a deadline problem so it is
			// not misreported (ctx.Err() returns context.Canceled exactly).
			code := codes.DeadlineExceeded
			if ctx.Err() == context.Canceled {
				code = codes.Canceled
			}
			return nil, status.Errorf(code, "rate limit wait cancelled: %v", err)
		}

		limiter.incrementAllowed()
		return handler(ctx, req)
	}
}

// RetryableError is a rate-limit rejection that also carries how long the
// caller should wait before retrying.
//
// It embeds the gRPC status (code and base message) of the rejection.
type RetryableError struct {
	*status.Status
	RetryAfter float64 // seconds
}

// Error implements the error interface. It appends the retry delay to the
// status message, e.g. "rate limit exceeded for x (retry after 1.5s)".
func (e *RetryableError) Error() string {
	return fmt.Sprintf("%s (retry after %.1fs)", e.Status.Message(), e.RetryAfter)
}

// GRPCStatus lets status.FromError, status.Code, and the gRPC server (when
// this error is returned from a handler) recover the intended status.
//
// Embedding *status.Status alone does not provide this method. Without it,
// gRPC treats the error as a plain error and sends codes.Unknown to the
// client. The returned status keeps the embedded code and uses Error() as
// its message, so the retry-after delay is visible to the client.
func (e *RetryableError) GRPCStatus() *status.Status {
	return status.New(e.Status.Code(), e.Error())
}

// UnaryInterceptorWithRetryAfter is like UnaryInterceptor, but a rejection
// caused by a temporarily empty bucket returns a *RetryableError. That error
// carries the delay after which a retry would be admitted, so callers can
// implement smarter backoff.
//
// The caller's per-method limiter is asked for a reservation:
//   - Reservation not OK: reject with codes.ResourceExhausted ("burst
//     exhausted").
//   - Reservation needs a delay: cancel it and reject with a *RetryableError
//     whose RetryAfter is that delay in seconds.
//   - Otherwise: admit the request.
//
// Every rejection is counted via incrementDenied; admitted requests are
// counted via incrementAllowed.
func UnaryInterceptorWithRetryAfter(limiter *Limiter) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		identifier := getIdentifier(ctx)

		// Per-identifier, per-method limiter.
		userLims := limiter.getUserLimiters(identifier)
		rateLimiter := userLims.getLimiterForMethod(info.FullMethod, limiter.config, identifier)

		// Reserve a token to learn how long the caller would have to wait.
		reservation := rateLimiter.Reserve()
		if !reservation.OK() {
			// The reservation can never be satisfied (presumably burst < 1 —
			// confirm against config), so there is no delay to report. It is
			// still a rejection and is counted as one.
			limiter.incrementDenied()
			return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded (burst exhausted)")
		}

		if delay := reservation.Delay(); delay > 0 {
			// We are not going to wait. Cancel the reservation so the rejected
			// call does not hold on to the reserved token (assumes x/time/rate
			// Reservation.Cancel semantics).
			reservation.Cancel()
			limiter.incrementDenied()

			st := status.New(codes.ResourceExhausted, fmt.Sprintf("rate limit exceeded for %s", identifier))
			return nil, &RetryableError{
				Status:     st,
				RetryAfter: delay.Seconds(),
			}
		}

		// Token available now: proceed.
		limiter.incrementAllowed()
		return handler(ctx, req)
	}
}