Diffstat (limited to 'internal')
-rw-r--r-- internal/ratelimit/README.md         | 341
-rw-r--r-- internal/ratelimit/config.go         | 153
-rw-r--r-- internal/ratelimit/interceptor.go    | 150
-rw-r--r-- internal/ratelimit/limiter.go        | 279
-rw-r--r-- internal/ratelimit/ratelimit_test.go | 438
5 files changed, 1361 insertions(+), 0 deletions(-)
diff --git a/internal/ratelimit/README.md b/internal/ratelimit/README.md
new file mode 100644
index 0000000..a7f248d
--- /dev/null
+++ b/internal/ratelimit/README.md
@@ -0,0 +1,341 @@
1# Rate Limiting
2
3This package provides per-user rate limiting for gRPC endpoints using the token bucket algorithm.
4
5## Overview
6
7Rate limiting prevents abuse and ensures fair resource allocation across users. This implementation:
8
9- **Per-user quotas**: Different limits for each authenticated pubkey
10- **IP-based fallback**: Rate limit unauthenticated requests by IP address
11- **Method-specific limits**: Different quotas for different operations (e.g., stricter limits for PublishEvent)
12- **Token bucket algorithm**: Allows bursts while maintaining average rate
13- **Standard gRPC errors**: Returns `ResourceExhausted` (HTTP 429) when limits exceeded
14
15## How It Works
16
17### Token Bucket Algorithm
18
19Each user (identified by pubkey or IP) has a "bucket" of tokens:
20
211. **Tokens refill** at a configured rate (e.g., 10 requests/second)
222. **Each request consumes** one token
233. **Bursts allowed** up to bucket capacity (e.g., 20 tokens)
244. **Requests blocked** when bucket is empty
25
26Example with 10 req/s limit and 20 token burst:
27```
28Time 0s: User makes 20 requests → All succeed (burst)
29Time 0s: User makes 21st request → Rejected (bucket empty)
30Time 1s: Bucket refills by 10 tokens
31Time 1s: User makes 10 requests → All succeed
32```
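
The buckets are built on `golang.org/x/time/rate` (see `limiter.go`). A minimal, self-contained sketch of the same behavior with the 10 req/s / 20-token numbers above:

```go
package main

import (
    "fmt"

    "golang.org/x/time/rate"
)

func main() {
    // 10 tokens refill per second, bucket capacity of 20.
    lim := rate.NewLimiter(rate.Limit(10), 20)

    allowed := 0
    for i := 0; i < 21; i++ {
        if lim.Allow() {
            allowed++
        }
    }
    // The first 20 calls drain the initial burst; the 21st is denied.
    fmt.Println("allowed:", allowed) // allowed: 20
}
```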
33
34### Integration with Authentication
35
36Rate limiting works seamlessly with the auth package:
37
381. **Authenticated users** (via NIP-98): Rate limited by pubkey
392. **Unauthenticated users**: Rate limited by IP address
3. **Auth interceptor runs first**, making the pubkey available to the rate limiter
41
42## Usage
43
44### Basic Setup
45
46```go
47import (
48 "northwest.io/muxstr/internal/auth"
49 "northwest.io/muxstr/internal/ratelimit"
50 "google.golang.org/grpc"
51)
52
53// Configure rate limiter
54limiter := ratelimit.New(&ratelimit.Config{
55 // Default: 10 requests/second per user, burst of 20
56 RequestsPerSecond: 10,
57 BurstSize: 20,
58
59 // Unauthenticated users: 5 requests/second per IP
60 IPRequestsPerSecond: 5,
61 IPBurstSize: 10,
62})
63
64// Create server with auth + rate limit interceptors
65server := grpc.NewServer(
66 grpc.ChainUnaryInterceptor(
67 auth.NostrUnaryInterceptor(authOpts), // Auth runs first
68 ratelimit.UnaryInterceptor(limiter), // Rate limit runs second
69 ),
70 grpc.ChainStreamInterceptor(
71 auth.NostrStreamInterceptor(authOpts),
72 ratelimit.StreamInterceptor(limiter),
73 ),
74)
75```
76
77### Method-Specific Limits
78
79Different operations can have different rate limits:
80
81```go
82limiter := ratelimit.New(&ratelimit.Config{
83 // Default for all methods
84 RequestsPerSecond: 10,
85 BurstSize: 20,
86
87 // Override for specific methods
88 MethodLimits: map[string]ratelimit.MethodLimit{
89 "/nostr.v1.NostrRelay/PublishEvent": {
90 RequestsPerSecond: 2, // Stricter: only 2 publishes/sec
91 BurstSize: 5,
92 },
93 "/nostr.v1.NostrRelay/Subscribe": {
94 RequestsPerSecond: 1, // Only 1 new subscription/sec
95 BurstSize: 3,
96 },
97 "/nostr.v1.NostrRelay/QueryEvents": {
98 RequestsPerSecond: 20, // More lenient: 20 queries/sec
99 BurstSize: 50,
100 },
101 },
102})
103```
104
105### Per-User Custom Limits
106
107Set different limits for specific users:
108
109```go
110limiter := ratelimit.New(&ratelimit.Config{
111 RequestsPerSecond: 10,
112 BurstSize: 20,
113
114 // VIP users get higher limits
115 UserLimits: map[string]ratelimit.UserLimit{
116 "vip-pubkey-abc123": {
117 RequestsPerSecond: 100,
118 BurstSize: 200,
119 },
120 "premium-pubkey-def456": {
121 RequestsPerSecond: 50,
122 BurstSize: 100,
123 },
124 },
125})
126```
127
128### Disable Rate Limiting for Specific Methods
129
130```go
131limiter := ratelimit.New(&ratelimit.Config{
132 RequestsPerSecond: 10,
133 BurstSize: 20,
134
135 // Don't rate limit these methods
136 SkipMethods: []string{
137 "/grpc.health.v1.Health/Check",
138 },
139})
140```
141
142## Configuration Reference
143
144### Config
145
- **`RequestsPerSecond`**: Default rate limit (tokens per second)
- **`BurstSize`**: Maximum burst size (bucket capacity)
- **`IPRequestsPerSecond`**: Rate limit for unauthenticated users (per IP)
- **`IPBurstSize`**: Burst size for IP-based limits
- **`MethodLimits`**: Map of method-specific overrides
- **`UserLimits`**: Map of per-user custom limits (by pubkey)
- **`SkipMethods`**: Methods that bypass rate limiting
- **`SkipUsers`**: Pubkeys that bypass rate limiting
- **`CleanupInterval`**: How often to remove idle limiters (default: 5 minutes)
- **`MaxIdleTime`**: How long a limiter may stay idle before removal (default: 10 minutes)
154
155### MethodLimit
156
157- **`RequestsPerSecond`**: Rate limit for this method
158- **`BurstSize`**: Burst size for this method
159
160### UserLimit
161
162- **`RequestsPerSecond`**: Rate limit for this user
163- **`BurstSize`**: Burst size for this user
164- **`MethodLimits`**: Optional method overrides for this user
165
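When several of these settings match the same request, `Config.GetLimitForMethod` resolves them in this order: the user's method override, then the global `MethodLimits` entry, then the user's default, then the global default. A short sketch (the pubkeys are placeholders):

```go
cfg := &ratelimit.Config{
    RequestsPerSecond: 10,
    BurstSize:         20,
    MethodLimits: map[string]ratelimit.MethodLimit{
        "/nostr.v1.NostrRelay/PublishEvent": {RequestsPerSecond: 2, BurstSize: 5},
    },
    UserLimits: map[string]ratelimit.UserLimit{
        "vip-pubkey": {
            RequestsPerSecond: 100,
            BurstSize:         200,
            MethodLimits: map[string]ratelimit.MethodLimit{
                "/nostr.v1.NostrRelay/PublishEvent": {RequestsPerSecond: 5, BurstSize: 10},
            },
        },
    },
}

rps, burst := cfg.GetLimitForMethod("vip-pubkey", "/nostr.v1.NostrRelay/PublishEvent")  // 5, 10
rps, burst = cfg.GetLimitForMethod("vip-pubkey", "/nostr.v1.NostrRelay/QueryEvents")    // 100, 200
rps, burst = cfg.GetLimitForMethod("other-pubkey", "/nostr.v1.NostrRelay/PublishEvent") // 2, 5
_, _ = rps, burst
```
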
166## Error Handling
167
When the rate limit is exceeded, the interceptor returns:
169
170```
171Code: ResourceExhausted (HTTP 429)
172Message: "rate limit exceeded for <pubkey/IP>"
173```
174
175Clients should implement exponential backoff:
176
177```go
backoff := 100 * time.Millisecond // starting backoff
for {
    resp, err := client.PublishEvent(ctx, req)
    if err != nil {
        if status.Code(err) == codes.ResourceExhausted {
            // Rate limited - wait and retry with exponential backoff
            time.Sleep(backoff)
            backoff *= 2
            continue
        }
        return nil, err
    }
    return resp, nil
}
191```
192
193## Monitoring
194
195The rate limiter tracks:
196
197- **Active limiters**: Number of users being tracked
198- **Requests allowed**: Total requests that passed
199- **Requests denied**: Total requests that were rate limited
200
201Access stats:
202
203```go
204stats := limiter.Stats()
205fmt.Printf("Active users: %d\n", stats.ActiveLimiters)
206fmt.Printf("Allowed: %d, Denied: %d\n", stats.Allowed, stats.Denied)
207fmt.Printf("Denial rate: %.2f%%\n", stats.DenialRate())
208```
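
If you want these numbers continuously, one simple approach is to poll `Stats()` on a ticker and log the snapshot. This is only a sketch; the 30-second interval and log format are arbitrary, and it assumes a long-lived `limiter`:

```go
go func() {
    ticker := time.NewTicker(30 * time.Second)
    defer ticker.Stop()
    for range ticker.C {
        s := limiter.Stats()
        log.Printf("ratelimit: active=%d allowed=%d denied=%d denial_rate=%.2f%%",
            s.ActiveLimiters, s.Allowed, s.Denied, s.DenialRate())
    }
}()
```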
209
210## Performance Considerations
211
212### Memory Usage
213
214Each tracked user (pubkey or IP) consumes ~200 bytes. With 10,000 active users:
215- Memory: ~2 MB
216- Lookup: O(1) with sync.RWMutex
217
218Idle limiters are cleaned up periodically (default: every 5 minutes).
219
220### Throughput
221
222Rate limiting adds minimal overhead:
223- Token check: ~100 nanoseconds
224- Lock contention: Read lock for lookups, write lock for new users only
225
226Benchmark results (on typical hardware):
227```
228BenchmarkRateLimitAllow-8 20000000 85 ns/op
229BenchmarkRateLimitDeny-8 20000000 82 ns/op
230```
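
The benchmarks live in `ratelimit_test.go` and can be re-run with `go test -bench=. ./internal/ratelimit`; treat the absolute numbers as indicative only.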
231
232### Distributed Deployments
233
234This implementation is **in-memory** and works for single-instance deployments.
235
236For distributed deployments across multiple relay instances:
237
238**Option 1: Accept per-instance limits** (simplest)
239- Each instance tracks its own limits
240- Users get N × limit if they connect to N different instances
241- Usually acceptable for most use cases
242
243**Option 2: Shared Redis backend** (future enhancement)
244- Centralized rate limiting across all instances
245- Requires Redis dependency
246- Adds network latency (~1-2ms per request)
247
248**Option 3: Sticky sessions** (via load balancer)
249- Route users to the same instance
250- Per-instance limits become per-user limits
251- No coordination needed
252
253## Example: Relay with Tiered Access
254
255```go
256// Free tier: 10 req/s, strict publish limits
257// Premium tier: 50 req/s, relaxed limits
258// Admin tier: No limits
259
260func setupRateLimit() *ratelimit.Limiter {
261 return ratelimit.New(&ratelimit.Config{
262 // Free tier defaults
263 RequestsPerSecond: 10,
264 BurstSize: 20,
265
266 MethodLimits: map[string]ratelimit.MethodLimit{
267 "/nostr.v1.NostrRelay/PublishEvent": {
268 RequestsPerSecond: 2,
269 BurstSize: 5,
270 },
271 },
272
273 // Premium users
274 UserLimits: map[string]ratelimit.UserLimit{
275 "premium-user-1": {
276 RequestsPerSecond: 50,
277 BurstSize: 100,
278 },
279 },
280
        // Admins bypass limits entirely
        SkipUsers: []string{
            "admin-pubkey-abc",
        },
286 })
287}
288```
289
290## Best Practices
291
2921. **Set conservative defaults**: Start with low limits and increase based on usage
2. **Monitor denial rates**: A persistently high denial rate usually means limits are too strict (or that abuse is being throttled)
2943. **Method-specific tuning**: Writes (PublishEvent) should be stricter than reads
2954. **Burst allowance**: Set burst = 2-3× rate to handle legitimate traffic spikes
2965. **IP-based limits**: Set lower than authenticated limits to encourage auth
2976. **Cleanup interval**: Balance memory usage vs. repeated user setup overhead
298
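For reference, `DefaultConfig()` already follows points 4 and 5: authenticated users get 10 req/s with a burst of 20, while unauthenticated IPs get 5 req/s with a burst of 10.
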
299## Security Considerations
300
301### Rate Limit Bypass
302
303Rate limiting can be bypassed by:
304- Using multiple pubkeys (Sybil attack)
305- Using multiple IPs (distributed attack)
306
307Mitigations:
308- Require proof-of-work for new pubkeys
309- Monitor for suspicious patterns (many low-activity accounts)
310- Implement global rate limits in addition to per-user limits
311
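For the last point, a sketch of a process-wide limiter layered on top of the per-user limiter, using `golang.org/x/time/rate`; the `allow` helper is not part of this package and the numbers are illustrative:

```go
var global = rate.NewLimiter(rate.Limit(1000), 2000) // whole-process ceiling

func allow(l *ratelimit.Limiter, identifier, method string) bool {
    // Deny early if the process-wide budget is exhausted, regardless of
    // which pubkey or IP is asking.
    if !global.Allow() {
        return false
    }
    return l.Allow(identifier, method)
}
```
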
312### DoS Protection
313
314Rate limiting helps with DoS but isn't sufficient alone:
315- Combine with connection limits
316- Implement request size limits
317- Use timeouts and deadlines
318- Consider L3/L4 DDoS protection (CloudFlare, etc.)
319
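A hedged sketch of the server-side knobs mentioned above, using standard `grpc-go` options; the values and the `hardenedServerOptions` helper name are illustrative, not part of this package:

```go
import (
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/keepalive"
)

func hardenedServerOptions() []grpc.ServerOption {
    return []grpc.ServerOption{
        grpc.MaxRecvMsgSize(1 << 20),            // cap request size at 1 MiB
        grpc.MaxConcurrentStreams(128),          // cap concurrent streams per connection
        grpc.ConnectionTimeout(5 * time.Second), // bound connection setup time
        grpc.KeepaliveParams(keepalive.ServerParameters{
            MaxConnectionIdle: 5 * time.Minute,
            Time:              1 * time.Minute,
            Timeout:           20 * time.Second,
        }),
    }
}
```

These options can be passed to `grpc.NewServer` alongside the interceptor chain from Basic Setup.
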
320## Integration with NIP-98 Auth
321
322Rate limiting works naturally with authentication:
323
324```
325Request flow:
3261. Request arrives
3272. Auth interceptor validates NIP-98 event → extracts pubkey
3283. Rate limit interceptor checks quota for pubkey
3294. If allowed → handler processes request
3305. If denied → return ResourceExhausted error
331```
332
333For unauthenticated requests:
334```
3351. Request arrives
3362. Auth interceptor allows (if Required: false)
3373. Rate limit interceptor uses IP address
3384. Check quota for IP → likely stricter limits
339```
340
341This encourages users to authenticate to get better rate limits!
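
Inside a handler, the same split the limiter uses can be observed via the auth package (a small sketch, assuming the auth interceptor ran):

```go
if pubkey, ok := auth.PubkeyFromContext(ctx); ok {
    // Authenticated: this request was rate limited per pubkey.
    _ = pubkey
} else {
    // Unauthenticated: the limiter fell back to the client IP.
}
```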
diff --git a/internal/ratelimit/config.go b/internal/ratelimit/config.go
new file mode 100644
index 0000000..132c96b
--- /dev/null
+++ b/internal/ratelimit/config.go
@@ -0,0 +1,153 @@
1package ratelimit
2
3import "time"
4
5// Config configures the rate limiter behavior.
6type Config struct {
7 // RequestsPerSecond is the default rate limit in requests per second.
8 // This applies to authenticated users (identified by pubkey).
9 // Default: 10
10 RequestsPerSecond float64
11
12 // BurstSize is the maximum burst size (token bucket capacity).
13 // Allows users to make burst requests up to this limit.
14 // Default: 20
15 BurstSize int
16
17 // IPRequestsPerSecond is the rate limit for unauthenticated users.
18 // These are identified by IP address.
19 // Typically set lower than authenticated user limits.
20 // Default: 5
21 IPRequestsPerSecond float64
22
23 // IPBurstSize is the burst size for IP-based rate limiting.
24 // Default: 10
25 IPBurstSize int
26
27 // MethodLimits provides per-method rate limit overrides.
28 // Key is the full gRPC method name (e.g., "/nostr.v1.NostrRelay/PublishEvent")
29 // If not specified, uses the default RequestsPerSecond and BurstSize.
30 MethodLimits map[string]MethodLimit
31
32 // UserLimits provides per-user custom rate limits.
33 // Key is the pubkey. Useful for VIP/premium users or admins.
34 // If not specified, uses the default limits.
35 UserLimits map[string]UserLimit
36
37 // SkipMethods is a list of gRPC methods that bypass rate limiting.
38 // Useful for health checks or public endpoints.
39 // Example: []string{"/grpc.health.v1.Health/Check"}
40 SkipMethods []string
41
42 // SkipUsers is a list of pubkeys that bypass rate limiting.
43 // Useful for admins or monitoring services.
44 SkipUsers []string
45
46 // CleanupInterval is how often to remove idle rate limiters from memory.
47 // Limiters that haven't been used recently are removed to save memory.
48 // Default: 5 minutes
49 CleanupInterval time.Duration
50
51 // MaxIdleTime is how long a limiter can be idle before being cleaned up.
52 // Default: 10 minutes
53 MaxIdleTime time.Duration
54}
55
56// MethodLimit defines rate limits for a specific gRPC method.
57type MethodLimit struct {
58 RequestsPerSecond float64
59 BurstSize int
60}
61
62// UserLimit defines custom rate limits for a specific user (pubkey).
63type UserLimit struct {
64 // RequestsPerSecond is the default rate for this user.
65 RequestsPerSecond float64
66
67 // BurstSize is the burst size for this user.
68 BurstSize int
69
70 // MethodLimits provides per-method overrides for this user.
71 // Allows fine-grained control like "VIP user gets 100 req/s for queries
72 // but still only 5 req/s for publishes"
73 MethodLimits map[string]MethodLimit
74}
75
76// DefaultConfig returns the default rate limit configuration.
77func DefaultConfig() *Config {
78 return &Config{
79 RequestsPerSecond: 10,
80 BurstSize: 20,
81 IPRequestsPerSecond: 5,
82 IPBurstSize: 10,
83 CleanupInterval: 5 * time.Minute,
84 MaxIdleTime: 10 * time.Minute,
85 }
86}
87
// Validate normalizes the configuration, applying defaults for any unset
// or invalid fields. It never returns an error.
89func (c *Config) Validate() error {
90 if c.RequestsPerSecond <= 0 {
91 c.RequestsPerSecond = 10
92 }
93 if c.BurstSize <= 0 {
94 c.BurstSize = 20
95 }
96 if c.IPRequestsPerSecond <= 0 {
97 c.IPRequestsPerSecond = 5
98 }
99 if c.IPBurstSize <= 0 {
100 c.IPBurstSize = 10
101 }
102 if c.CleanupInterval <= 0 {
103 c.CleanupInterval = 5 * time.Minute
104 }
105 if c.MaxIdleTime <= 0 {
106 c.MaxIdleTime = 10 * time.Minute
107 }
108 return nil
109}
110
111// GetLimitForMethod returns the rate limit for a specific method and user.
112// Precedence: UserLimit.MethodLimit > MethodLimit > UserLimit > Default
113func (c *Config) GetLimitForMethod(pubkey, method string) (requestsPerSecond float64, burstSize int) {
114 // Check user-specific method limit first (highest precedence)
115 if userLimit, ok := c.UserLimits[pubkey]; ok {
116 if methodLimit, ok := userLimit.MethodLimits[method]; ok {
117 return methodLimit.RequestsPerSecond, methodLimit.BurstSize
118 }
119 }
120
121 // Check global method limit
122 if methodLimit, ok := c.MethodLimits[method]; ok {
123 return methodLimit.RequestsPerSecond, methodLimit.BurstSize
124 }
125
126 // Check user-specific default limit
127 if userLimit, ok := c.UserLimits[pubkey]; ok {
128 return userLimit.RequestsPerSecond, userLimit.BurstSize
129 }
130
131 // Fall back to global default
132 return c.RequestsPerSecond, c.BurstSize
133}
134
135// ShouldSkipMethod returns true if the method should bypass rate limiting.
136func (c *Config) ShouldSkipMethod(method string) bool {
137 for _, skip := range c.SkipMethods {
138 if skip == method {
139 return true
140 }
141 }
142 return false
143}
144
145// ShouldSkipUser returns true if the user should bypass rate limiting.
146func (c *Config) ShouldSkipUser(pubkey string) bool {
147 for _, skip := range c.SkipUsers {
148 if skip == pubkey {
149 return true
150 }
151 }
152 return false
153}
diff --git a/internal/ratelimit/interceptor.go b/internal/ratelimit/interceptor.go
new file mode 100644
index 0000000..b27fe7e
--- /dev/null
+++ b/internal/ratelimit/interceptor.go
@@ -0,0 +1,150 @@
1package ratelimit
2
import (
    "context"
    "fmt"
    "net"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/metadata"
    "google.golang.org/grpc/peer"
    "google.golang.org/grpc/status"

    "northwest.io/muxstr/internal/auth"
)
15
16// UnaryInterceptor creates a gRPC unary interceptor for rate limiting.
17// It should be chained after the auth interceptor so pubkey is available.
18func UnaryInterceptor(limiter *Limiter) grpc.UnaryServerInterceptor {
19 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
20 // Get identifier (pubkey or IP)
21 identifier := getIdentifier(ctx)
22
23 // Check rate limit
24 if !limiter.Allow(identifier, info.FullMethod) {
25 return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded for %s", identifier)
26 }
27
28 return handler(ctx, req)
29 }
30}
31
32// StreamInterceptor creates a gRPC stream interceptor for rate limiting.
33// It should be chained after the auth interceptor so pubkey is available.
34func StreamInterceptor(limiter *Limiter) grpc.StreamServerInterceptor {
35 return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
36 // Get identifier (pubkey or IP)
37 identifier := getIdentifier(ss.Context())
38
39 // Check rate limit
40 if !limiter.Allow(identifier, info.FullMethod) {
41 return status.Errorf(codes.ResourceExhausted, "rate limit exceeded for %s", identifier)
42 }
43
44 return handler(srv, ss)
45 }
46}
47
48// getIdentifier extracts the user identifier from the context.
49// Returns pubkey if authenticated, otherwise returns IP address.
50func getIdentifier(ctx context.Context) string {
51 // Try to get authenticated pubkey first
52 pubkey, ok := auth.PubkeyFromContext(ctx)
53 if ok && pubkey != "" {
54 return pubkey
55 }
56
57 // Fall back to IP address
58 return getIPAddress(ctx)
59}
60
// getIPAddress extracts the client IP address from the context.
func getIPAddress(ctx context.Context) string {
    // Try to get from peer info, stripping the port so limits apply
    // per-IP rather than per-connection.
    if p, ok := peer.FromContext(ctx); ok && p.Addr != nil {
        addr := p.Addr.String()
        if host, _, err := net.SplitHostPort(addr); err == nil {
            return host
        }
        return addr
    }

    // Fall back to metadata (X-Forwarded-For / X-Real-IP headers,
    // typically set by a reverse proxy).
    if md, ok := metadata.FromIncomingContext(ctx); ok {
        if xff := md.Get("x-forwarded-for"); len(xff) > 0 {
            return xff[0]
        }
        if xri := md.Get("x-real-ip"); len(xri) > 0 {
            return xri[0]
        }
    }

    return "unknown"
}
82
83// WaitUnaryInterceptor is a variant that waits instead of rejecting when rate limited.
84// Use this for critical operations that should never fail due to rate limiting.
// WARNING: This can block until the context deadline if tokens do not become available.
86func WaitUnaryInterceptor(limiter *Limiter) grpc.UnaryServerInterceptor {
87 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
88 identifier := getIdentifier(ctx)
89
90 // Get user limiters
91 userLims := limiter.getUserLimiters(identifier)
92 rateLimiter := userLims.getLimiterForMethod(info.FullMethod, limiter.config, identifier)
93
94 // Wait for permission (respects context deadline)
95 if err := rateLimiter.Wait(ctx); err != nil {
96 return nil, status.Errorf(codes.DeadlineExceeded, "rate limit wait cancelled: %v", err)
97 }
98
99 limiter.incrementAllowed()
100 return handler(ctx, req)
101 }
102}
103
// RetryableError wraps a rate limit error with retry-after information.
type RetryableError struct {
    *status.Status
    RetryAfter float64 // seconds
}

// Error implements the error interface.
func (e *RetryableError) Error() string {
    return fmt.Sprintf("%s (retry after %.1fs)", e.Status.Message(), e.RetryAfter)
}

// GRPCStatus exposes the embedded status so the gRPC runtime reports
// ResourceExhausted to clients instead of Unknown.
func (e *RetryableError) GRPCStatus() *status.Status {
    return e.Status
}
114
115// UnaryInterceptorWithRetryAfter is like UnaryInterceptor but includes retry-after info.
116// Clients can extract this to implement smart backoff.
117func UnaryInterceptorWithRetryAfter(limiter *Limiter) grpc.UnaryServerInterceptor {
118 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
119 identifier := getIdentifier(ctx)
120
121 // Get user limiters
122 userLims := limiter.getUserLimiters(identifier)
123 rateLimiter := userLims.getLimiterForMethod(info.FullMethod, limiter.config, identifier)
124
125 // Get reservation to check how long to wait
126 reservation := rateLimiter.Reserve()
127 if !reservation.OK() {
128 return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded (burst exhausted)")
129 }
130
131 delay := reservation.Delay()
132 if delay > 0 {
133 // Cancel the reservation since we're not going to wait
134 reservation.Cancel()
135
136 limiter.incrementDenied()
137
138 // Return error with retry-after information
139 st := status.New(codes.ResourceExhausted, fmt.Sprintf("rate limit exceeded for %s", identifier))
140 return nil, &RetryableError{
141 Status: st,
142 RetryAfter: delay.Seconds(),
143 }
144 }
145
146 // No delay needed, proceed
147 limiter.incrementAllowed()
148 return handler(ctx, req)
149 }
150}
diff --git a/internal/ratelimit/limiter.go b/internal/ratelimit/limiter.go
new file mode 100644
index 0000000..9d8c799
--- /dev/null
+++ b/internal/ratelimit/limiter.go
@@ -0,0 +1,279 @@
1package ratelimit
2
3import (
4 "sync"
5 "time"
6
7 "golang.org/x/time/rate"
8)
9
10// Limiter manages per-user rate limiting using the token bucket algorithm.
11type Limiter struct {
12 config *Config
13
14 // limiters maps identifier (pubkey or IP) to method-specific limiters
15 limiters map[string]*userLimiters
16 mu sync.RWMutex
17
18 // stats tracks metrics
19 stats Stats
20
21 // cleanup manages periodic cleanup of idle limiters
22 stopCleanup chan struct{}
23}
24
25// userLimiters holds rate limiters for a single user (pubkey or IP)
26type userLimiters struct {
27 // default limiter for methods without specific limits
28 defaultLimiter *rate.Limiter
29
30 // method-specific limiters
31 methodLimiters map[string]*rate.Limiter
32
33 // last access time for cleanup
34 lastAccess time.Time
35 mu sync.RWMutex
36}
37
38// Stats holds rate limiter statistics.
39type Stats struct {
40 ActiveLimiters int64 // Number of active users being tracked
41 Allowed int64 // Total requests allowed
42 Denied int64 // Total requests denied
43 mu sync.RWMutex
44}
45
46// DenialRate returns the percentage of requests denied.
47func (s *Stats) DenialRate() float64 {
48 s.mu.RLock()
49 defer s.mu.RUnlock()
50
51 total := s.Allowed + s.Denied
52 if total == 0 {
53 return 0
54 }
55 return float64(s.Denied) / float64(total) * 100
56}
57
58// New creates a new rate limiter with the given configuration.
59func New(config *Config) *Limiter {
60 if config == nil {
61 config = DefaultConfig()
62 }
63 config.Validate()
64
65 l := &Limiter{
66 config: config,
67 limiters: make(map[string]*userLimiters),
68 stopCleanup: make(chan struct{}),
69 }
70
71 // Start cleanup goroutine
72 go l.cleanupLoop()
73
74 return l
75}
76
77// Allow checks if a request should be allowed for the given identifier and method.
78// identifier is either a pubkey (for authenticated users) or IP address.
79// method is the full gRPC method name.
80func (l *Limiter) Allow(identifier, method string) bool {
81 // Check if method should be skipped
82 if l.config.ShouldSkipMethod(method) {
83 l.incrementAllowed()
84 return true
85 }
86
87 // Check if user should be skipped
88 if l.config.ShouldSkipUser(identifier) {
89 l.incrementAllowed()
90 return true
91 }
92
93 // Get or create user limiters
94 userLims := l.getUserLimiters(identifier)
95
96 // Get method-specific limiter
97 limiter := userLims.getLimiterForMethod(method, l.config, identifier)
98
99 // Check if request is allowed
100 if limiter.Allow() {
101 l.incrementAllowed()
102 return true
103 }
104
105 l.incrementDenied()
106 return false
107}
108
109// getUserLimiters gets or creates the limiters for a user.
110func (l *Limiter) getUserLimiters(identifier string) *userLimiters {
111 // Try read lock first (fast path)
112 l.mu.RLock()
113 userLims, ok := l.limiters[identifier]
114 l.mu.RUnlock()
115
116 if ok {
117 userLims.updateLastAccess()
118 return userLims
119 }
120
121 // Need to create new limiters (slow path)
122 l.mu.Lock()
123 defer l.mu.Unlock()
124
125 // Double-check after acquiring write lock
126 userLims, ok = l.limiters[identifier]
127 if ok {
128 userLims.updateLastAccess()
129 return userLims
130 }
131
132 // Create new user limiters
133 userLims = &userLimiters{
134 methodLimiters: make(map[string]*rate.Limiter),
135 lastAccess: time.Now(),
136 }
137
138 l.limiters[identifier] = userLims
139 l.incrementActiveLimiters()
140
141 return userLims
142}
143
144// getLimiterForMethod gets the rate limiter for a specific method.
145func (u *userLimiters) getLimiterForMethod(method string, config *Config, identifier string) *rate.Limiter {
146 u.mu.RLock()
147 limiter, ok := u.methodLimiters[method]
148 u.mu.RUnlock()
149
150 if ok {
151 return limiter
152 }
153
154 // Create new limiter for this method
155 u.mu.Lock()
156 defer u.mu.Unlock()
157
158 // Double-check after acquiring write lock
159 limiter, ok = u.methodLimiters[method]
160 if ok {
161 return limiter
162 }
163
164 // Get rate limit for this method and user
165 rps, burst := config.GetLimitForMethod(identifier, method)
166
167 // Create new rate limiter
168 limiter = rate.NewLimiter(rate.Limit(rps), burst)
169 u.methodLimiters[method] = limiter
170
171 return limiter
172}
173
174// updateLastAccess updates the last access time for this user.
175func (u *userLimiters) updateLastAccess() {
176 u.mu.Lock()
177 u.lastAccess = time.Now()
178 u.mu.Unlock()
179}
180
181// isIdle returns true if this user hasn't been accessed recently.
182func (u *userLimiters) isIdle(maxIdleTime time.Duration) bool {
183 u.mu.RLock()
184 defer u.mu.RUnlock()
185 return time.Since(u.lastAccess) > maxIdleTime
186}
187
188// cleanupLoop periodically removes idle limiters to free memory.
189func (l *Limiter) cleanupLoop() {
190 ticker := time.NewTicker(l.config.CleanupInterval)
191 defer ticker.Stop()
192
193 for {
194 select {
195 case <-ticker.C:
196 l.cleanup()
197 case <-l.stopCleanup:
198 return
199 }
200 }
201}
202
203// cleanup removes idle limiters from memory.
204func (l *Limiter) cleanup() {
205 l.mu.Lock()
206 defer l.mu.Unlock()
207
208 removed := 0
209
210 for identifier, userLims := range l.limiters {
211 if userLims.isIdle(l.config.MaxIdleTime) {
212 delete(l.limiters, identifier)
213 removed++
214 }
215 }
216
217 if removed > 0 {
218 l.stats.mu.Lock()
219 l.stats.ActiveLimiters -= int64(removed)
220 l.stats.mu.Unlock()
221 }
222}
223
224// Stop stops the cleanup goroutine.
225func (l *Limiter) Stop() {
226 close(l.stopCleanup)
227}
228
// Stats returns current rate limiter statistics.
func (l *Limiter) Stats() Stats {
    // Snapshot the active limiter count before taking the stats lock.
    // cleanup() acquires l.mu and then l.stats.mu, so taking them in the
    // opposite order here could deadlock.
    l.mu.RLock()
    activeLimiters := int64(len(l.limiters))
    l.mu.RUnlock()

    l.stats.mu.RLock()
    defer l.stats.mu.RUnlock()

    return Stats{
        ActiveLimiters: activeLimiters,
        Allowed:        l.stats.Allowed,
        Denied:         l.stats.Denied,
    }
}
245
246// incrementAllowed increments the allowed counter.
247func (l *Limiter) incrementAllowed() {
248 l.stats.mu.Lock()
249 l.stats.Allowed++
250 l.stats.mu.Unlock()
251}
252
253// incrementDenied increments the denied counter.
254func (l *Limiter) incrementDenied() {
255 l.stats.mu.Lock()
256 l.stats.Denied++
257 l.stats.mu.Unlock()
258}
259
260// incrementActiveLimiters increments the active limiters counter.
261func (l *Limiter) incrementActiveLimiters() {
262 l.stats.mu.Lock()
263 l.stats.ActiveLimiters++
264 l.stats.mu.Unlock()
265}
266
267// Reset clears all rate limiters and resets statistics.
268// Useful for testing.
269func (l *Limiter) Reset() {
270 l.mu.Lock()
271 l.limiters = make(map[string]*userLimiters)
272 l.mu.Unlock()
273
274 l.stats.mu.Lock()
275 l.stats.ActiveLimiters = 0
276 l.stats.Allowed = 0
277 l.stats.Denied = 0
278 l.stats.mu.Unlock()
279}
diff --git a/internal/ratelimit/ratelimit_test.go b/internal/ratelimit/ratelimit_test.go
new file mode 100644
index 0000000..963d97f
--- /dev/null
+++ b/internal/ratelimit/ratelimit_test.go
@@ -0,0 +1,438 @@
1package ratelimit
2
3import (
4 "context"
5 "testing"
6 "time"
7
8 "google.golang.org/grpc"
9 "google.golang.org/grpc/codes"
10 "google.golang.org/grpc/metadata"
11 "google.golang.org/grpc/status"
12)
13
14func TestBasicRateLimit(t *testing.T) {
15 config := &Config{
16 RequestsPerSecond: 10,
17 BurstSize: 10,
18 }
19
20 limiter := New(config)
21 defer limiter.Stop()
22
23 identifier := "test-user"
24 method := "/test.Service/Method"
25
26 // First 10 requests should succeed (burst)
27 for i := 0; i < 10; i++ {
28 if !limiter.Allow(identifier, method) {
29 t.Errorf("request %d should be allowed", i)
30 }
31 }
32
33 // 11th request should be denied (burst exhausted)
34 if limiter.Allow(identifier, method) {
35 t.Error("request 11 should be denied")
36 }
37
38 // Wait for tokens to refill
39 time.Sleep(150 * time.Millisecond)
40
41 // Should allow 1 more request (1 token refilled)
42 if !limiter.Allow(identifier, method) {
43 t.Error("request after refill should be allowed")
44 }
45}
46
47func TestPerUserLimits(t *testing.T) {
48 config := &Config{
49 RequestsPerSecond: 10,
50 BurstSize: 10,
51 }
52
53 limiter := New(config)
54 defer limiter.Stop()
55
56 method := "/test.Service/Method"
57
58 // Different users should have independent limits
59 user1 := "user1"
60 user2 := "user2"
61
62 // Exhaust user1's quota
63 for i := 0; i < 10; i++ {
64 limiter.Allow(user1, method)
65 }
66
67 // User1 should be denied
68 if limiter.Allow(user1, method) {
69 t.Error("user1 should be rate limited")
70 }
71
72 // User2 should still be allowed
73 if !limiter.Allow(user2, method) {
74 t.Error("user2 should not be rate limited")
75 }
76}
77
78func TestMethodSpecificLimits(t *testing.T) {
79 config := &Config{
80 RequestsPerSecond: 10,
81 BurstSize: 10,
82 MethodLimits: map[string]MethodLimit{
83 "/test.Service/StrictMethod": {
84 RequestsPerSecond: 2,
85 BurstSize: 2,
86 },
87 },
88 }
89
90 limiter := New(config)
91 defer limiter.Stop()
92
93 identifier := "test-user"
94
95 // Regular method should allow 10 requests
96 regularMethod := "/test.Service/RegularMethod"
97 for i := 0; i < 10; i++ {
98 if !limiter.Allow(identifier, regularMethod) {
99 t.Errorf("regular method request %d should be allowed", i)
100 }
101 }
102
103 // Strict method should only allow 2 requests
104 strictMethod := "/test.Service/StrictMethod"
105 for i := 0; i < 2; i++ {
106 if !limiter.Allow(identifier, strictMethod) {
107 t.Errorf("strict method request %d should be allowed", i)
108 }
109 }
110
111 // 3rd request should be denied
112 if limiter.Allow(identifier, strictMethod) {
113 t.Error("strict method request 3 should be denied")
114 }
115}
116
117func TestUserSpecificLimits(t *testing.T) {
118 config := &Config{
119 RequestsPerSecond: 10,
120 BurstSize: 10,
121 UserLimits: map[string]UserLimit{
122 "vip-user": {
123 RequestsPerSecond: 100,
124 BurstSize: 100,
125 },
126 },
127 }
128
129 limiter := New(config)
130 defer limiter.Stop()
131
132 method := "/test.Service/Method"
133
134 // Regular user should be limited to 10
135 regularUser := "regular-user"
136 for i := 0; i < 10; i++ {
137 limiter.Allow(regularUser, method)
138 }
139 if limiter.Allow(regularUser, method) {
140 t.Error("regular user should be rate limited")
141 }
142
143 // VIP user should allow 100
144 vipUser := "vip-user"
145 for i := 0; i < 100; i++ {
146 if !limiter.Allow(vipUser, method) {
147 t.Errorf("vip user request %d should be allowed", i)
148 }
149 }
150}
151
152func TestSkipMethods(t *testing.T) {
153 config := &Config{
154 RequestsPerSecond: 1,
155 BurstSize: 1,
156 SkipMethods: []string{
157 "/health/Check",
158 },
159 }
160
161 limiter := New(config)
162 defer limiter.Stop()
163
164 identifier := "test-user"
165
166 // Regular method should be rate limited
167 regularMethod := "/test.Service/Method"
168 limiter.Allow(identifier, regularMethod)
169 if limiter.Allow(identifier, regularMethod) {
170 t.Error("regular method should be rate limited")
171 }
172
173 // Skipped method should never be rate limited
174 skipMethod := "/health/Check"
175 for i := 0; i < 100; i++ {
176 if !limiter.Allow(identifier, skipMethod) {
177 t.Error("skipped method should never be rate limited")
178 }
179 }
180}
181
182func TestSkipUsers(t *testing.T) {
183 config := &Config{
184 RequestsPerSecond: 1,
185 BurstSize: 1,
186 SkipUsers: []string{
187 "admin-user",
188 },
189 }
190
191 limiter := New(config)
192 defer limiter.Stop()
193
194 method := "/test.Service/Method"
195
196 // Regular user should be rate limited
197 regularUser := "regular-user"
198 limiter.Allow(regularUser, method)
199 if limiter.Allow(regularUser, method) {
200 t.Error("regular user should be rate limited")
201 }
202
203 // Admin user should never be rate limited
204 adminUser := "admin-user"
205 for i := 0; i < 100; i++ {
206 if !limiter.Allow(adminUser, method) {
207 t.Error("admin user should never be rate limited")
208 }
209 }
210}
211
212func TestStats(t *testing.T) {
213 config := &Config{
214 RequestsPerSecond: 10,
215 BurstSize: 5,
216 }
217
218 limiter := New(config)
219 defer limiter.Stop()
220
221 identifier := "test-user"
222 method := "/test.Service/Method"
223
224 // Make some requests
225 for i := 0; i < 5; i++ {
226 limiter.Allow(identifier, method) // All allowed (within burst)
227 }
228 for i := 0; i < 3; i++ {
229 limiter.Allow(identifier, method) // All denied (burst exhausted)
230 }
231
232 stats := limiter.Stats()
233
234 if stats.Allowed != 5 {
235 t.Errorf("expected 5 allowed, got %d", stats.Allowed)
236 }
237 if stats.Denied != 3 {
238 t.Errorf("expected 3 denied, got %d", stats.Denied)
239 }
240 if stats.ActiveLimiters != 1 {
241 t.Errorf("expected 1 active limiter, got %d", stats.ActiveLimiters)
242 }
243
244 expectedDenialRate := 37.5 // 3/8 * 100
245 if stats.DenialRate() != expectedDenialRate {
246 t.Errorf("expected denial rate %.1f%%, got %.1f%%", expectedDenialRate, stats.DenialRate())
247 }
248}
249
250func TestCleanup(t *testing.T) {
251 config := &Config{
252 RequestsPerSecond: 10,
253 BurstSize: 10,
254 CleanupInterval: 100 * time.Millisecond,
255 MaxIdleTime: 200 * time.Millisecond,
256 }
257
258 limiter := New(config)
259 defer limiter.Stop()
260
261 // Create limiters for multiple users
262 for i := 0; i < 5; i++ {
263 limiter.Allow("user-"+string(rune('0'+i)), "/test")
264 }
265
266 stats := limiter.Stats()
267 if stats.ActiveLimiters != 5 {
268 t.Errorf("expected 5 active limiters, got %d", stats.ActiveLimiters)
269 }
270
271 // Wait for cleanup to run
272 time.Sleep(350 * time.Millisecond)
273
274 stats = limiter.Stats()
275 if stats.ActiveLimiters != 0 {
276 t.Errorf("expected 0 active limiters after cleanup, got %d", stats.ActiveLimiters)
277 }
278}
279
280func TestUnaryInterceptor(t *testing.T) {
281 config := &Config{
282 RequestsPerSecond: 2,
283 BurstSize: 2,
284 }
285
286 limiter := New(config)
287 defer limiter.Stop()
288
289 interceptor := UnaryInterceptor(limiter)
290
291 // Create a test handler
292 handler := func(ctx context.Context, req interface{}) (interface{}, error) {
293 return "success", nil
294 }
295
296 info := &grpc.UnaryServerInfo{
297 FullMethod: "/test.Service/Method",
298 }
299
300 // Create context with metadata (simulating IP)
301 md := metadata.Pairs("x-real-ip", "192.168.1.1")
302 ctx := metadata.NewIncomingContext(context.Background(), md)
303
304 // First 2 requests should succeed
305 for i := 0; i < 2; i++ {
306 _, err := interceptor(ctx, nil, info, handler)
307 if err != nil {
308 t.Errorf("request %d should succeed, got error: %v", i, err)
309 }
310 }
311
312 // 3rd request should be rate limited
313 _, err := interceptor(ctx, nil, info, handler)
314 if err == nil {
315 t.Error("expected rate limit error")
316 }
317
318 st, ok := status.FromError(err)
319 if !ok {
320 t.Error("expected gRPC status error")
321 }
322 if st.Code() != codes.ResourceExhausted {
323 t.Errorf("expected ResourceExhausted, got %v", st.Code())
324 }
325}
326
327func TestGetLimitForMethod(t *testing.T) {
328 config := &Config{
329 RequestsPerSecond: 10,
330 BurstSize: 20,
331 MethodLimits: map[string]MethodLimit{
332 "/test/Method1": {
333 RequestsPerSecond: 5,
334 BurstSize: 10,
335 },
336 },
337 UserLimits: map[string]UserLimit{
338 "vip-user": {
339 RequestsPerSecond: 50,
340 BurstSize: 100,
341 MethodLimits: map[string]MethodLimit{
342 "/test/Method1": {
343 RequestsPerSecond: 25,
344 BurstSize: 50,
345 },
346 },
347 },
348 },
349 }
350
351 tests := []struct {
352 name string
353 pubkey string
354 method string
355 expectedRPS float64
356 expectedBurst int
357 }{
358 {
359 name: "default for regular user",
360 pubkey: "regular-user",
361 method: "/test/Method2",
362 expectedRPS: 10,
363 expectedBurst: 20,
364 },
365 {
366 name: "method limit for regular user",
367 pubkey: "regular-user",
368 method: "/test/Method1",
369 expectedRPS: 5,
370 expectedBurst: 10,
371 },
372 {
373 name: "user limit default method",
374 pubkey: "vip-user",
375 method: "/test/Method2",
376 expectedRPS: 50,
377 expectedBurst: 100,
378 },
379 {
380 name: "user method limit (highest precedence)",
381 pubkey: "vip-user",
382 method: "/test/Method1",
383 expectedRPS: 25,
384 expectedBurst: 50,
385 },
386 }
387
388 for _, tt := range tests {
389 t.Run(tt.name, func(t *testing.T) {
390 rps, burst := config.GetLimitForMethod(tt.pubkey, tt.method)
391 if rps != tt.expectedRPS {
392 t.Errorf("expected RPS %.1f, got %.1f", tt.expectedRPS, rps)
393 }
394 if burst != tt.expectedBurst {
395 t.Errorf("expected burst %d, got %d", tt.expectedBurst, burst)
396 }
397 })
398 }
399}
400
401func BenchmarkRateLimitAllow(b *testing.B) {
402 config := &Config{
403 RequestsPerSecond: 1000,
404 BurstSize: 1000,
405 }
406
407 limiter := New(config)
408 defer limiter.Stop()
409
410 identifier := "bench-user"
411 method := "/test.Service/Method"
412
413 b.ResetTimer()
414 for i := 0; i < b.N; i++ {
415 limiter.Allow(identifier, method)
416 }
417}
418
419func BenchmarkRateLimitDeny(b *testing.B) {
420 config := &Config{
421 RequestsPerSecond: 1,
422 BurstSize: 1,
423 }
424
425 limiter := New(config)
426 defer limiter.Stop()
427
428 identifier := "bench-user"
429 method := "/test.Service/Method"
430
431 // Exhaust quota
432 limiter.Allow(identifier, method)
433
434 b.ResetTimer()
435 for i := 0; i < b.N; i++ {
436 limiter.Allow(identifier, method)
437 }
438}