summaryrefslogtreecommitdiffstats
path: root/internal/ratelimit/limiter.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ratelimit/limiter.go')
-rw-r--r--internal/ratelimit/limiter.go279
1 files changed, 279 insertions, 0 deletions
diff --git a/internal/ratelimit/limiter.go b/internal/ratelimit/limiter.go
new file mode 100644
index 0000000..9d8c799
--- /dev/null
+++ b/internal/ratelimit/limiter.go
@@ -0,0 +1,279 @@
1package ratelimit
2
3import (
4 "sync"
5 "time"
6
7 "golang.org/x/time/rate"
8)
9
10// Limiter manages per-user rate limiting using the token bucket algorithm.
11type Limiter struct {
12 config *Config
13
14 // limiters maps identifier (pubkey or IP) to method-specific limiters
15 limiters map[string]*userLimiters
16 mu sync.RWMutex
17
18 // stats tracks metrics
19 stats Stats
20
21 // cleanup manages periodic cleanup of idle limiters
22 stopCleanup chan struct{}
23}
24
25// userLimiters holds rate limiters for a single user (pubkey or IP)
26type userLimiters struct {
27 // default limiter for methods without specific limits
28 defaultLimiter *rate.Limiter
29
30 // method-specific limiters
31 methodLimiters map[string]*rate.Limiter
32
33 // last access time for cleanup
34 lastAccess time.Time
35 mu sync.RWMutex
36}
37
38// Stats holds rate limiter statistics.
39type Stats struct {
40 ActiveLimiters int64 // Number of active users being tracked
41 Allowed int64 // Total requests allowed
42 Denied int64 // Total requests denied
43 mu sync.RWMutex
44}
45
46// DenialRate returns the percentage of requests denied.
47func (s *Stats) DenialRate() float64 {
48 s.mu.RLock()
49 defer s.mu.RUnlock()
50
51 total := s.Allowed + s.Denied
52 if total == 0 {
53 return 0
54 }
55 return float64(s.Denied) / float64(total) * 100
56}
57
58// New creates a new rate limiter with the given configuration.
59func New(config *Config) *Limiter {
60 if config == nil {
61 config = DefaultConfig()
62 }
63 config.Validate()
64
65 l := &Limiter{
66 config: config,
67 limiters: make(map[string]*userLimiters),
68 stopCleanup: make(chan struct{}),
69 }
70
71 // Start cleanup goroutine
72 go l.cleanupLoop()
73
74 return l
75}
76
77// Allow checks if a request should be allowed for the given identifier and method.
78// identifier is either a pubkey (for authenticated users) or IP address.
79// method is the full gRPC method name.
80func (l *Limiter) Allow(identifier, method string) bool {
81 // Check if method should be skipped
82 if l.config.ShouldSkipMethod(method) {
83 l.incrementAllowed()
84 return true
85 }
86
87 // Check if user should be skipped
88 if l.config.ShouldSkipUser(identifier) {
89 l.incrementAllowed()
90 return true
91 }
92
93 // Get or create user limiters
94 userLims := l.getUserLimiters(identifier)
95
96 // Get method-specific limiter
97 limiter := userLims.getLimiterForMethod(method, l.config, identifier)
98
99 // Check if request is allowed
100 if limiter.Allow() {
101 l.incrementAllowed()
102 return true
103 }
104
105 l.incrementDenied()
106 return false
107}
108
109// getUserLimiters gets or creates the limiters for a user.
110func (l *Limiter) getUserLimiters(identifier string) *userLimiters {
111 // Try read lock first (fast path)
112 l.mu.RLock()
113 userLims, ok := l.limiters[identifier]
114 l.mu.RUnlock()
115
116 if ok {
117 userLims.updateLastAccess()
118 return userLims
119 }
120
121 // Need to create new limiters (slow path)
122 l.mu.Lock()
123 defer l.mu.Unlock()
124
125 // Double-check after acquiring write lock
126 userLims, ok = l.limiters[identifier]
127 if ok {
128 userLims.updateLastAccess()
129 return userLims
130 }
131
132 // Create new user limiters
133 userLims = &userLimiters{
134 methodLimiters: make(map[string]*rate.Limiter),
135 lastAccess: time.Now(),
136 }
137
138 l.limiters[identifier] = userLims
139 l.incrementActiveLimiters()
140
141 return userLims
142}
143
144// getLimiterForMethod gets the rate limiter for a specific method.
145func (u *userLimiters) getLimiterForMethod(method string, config *Config, identifier string) *rate.Limiter {
146 u.mu.RLock()
147 limiter, ok := u.methodLimiters[method]
148 u.mu.RUnlock()
149
150 if ok {
151 return limiter
152 }
153
154 // Create new limiter for this method
155 u.mu.Lock()
156 defer u.mu.Unlock()
157
158 // Double-check after acquiring write lock
159 limiter, ok = u.methodLimiters[method]
160 if ok {
161 return limiter
162 }
163
164 // Get rate limit for this method and user
165 rps, burst := config.GetLimitForMethod(identifier, method)
166
167 // Create new rate limiter
168 limiter = rate.NewLimiter(rate.Limit(rps), burst)
169 u.methodLimiters[method] = limiter
170
171 return limiter
172}
173
174// updateLastAccess updates the last access time for this user.
175func (u *userLimiters) updateLastAccess() {
176 u.mu.Lock()
177 u.lastAccess = time.Now()
178 u.mu.Unlock()
179}
180
181// isIdle returns true if this user hasn't been accessed recently.
182func (u *userLimiters) isIdle(maxIdleTime time.Duration) bool {
183 u.mu.RLock()
184 defer u.mu.RUnlock()
185 return time.Since(u.lastAccess) > maxIdleTime
186}
187
188// cleanupLoop periodically removes idle limiters to free memory.
189func (l *Limiter) cleanupLoop() {
190 ticker := time.NewTicker(l.config.CleanupInterval)
191 defer ticker.Stop()
192
193 for {
194 select {
195 case <-ticker.C:
196 l.cleanup()
197 case <-l.stopCleanup:
198 return
199 }
200 }
201}
202
203// cleanup removes idle limiters from memory.
204func (l *Limiter) cleanup() {
205 l.mu.Lock()
206 defer l.mu.Unlock()
207
208 removed := 0
209
210 for identifier, userLims := range l.limiters {
211 if userLims.isIdle(l.config.MaxIdleTime) {
212 delete(l.limiters, identifier)
213 removed++
214 }
215 }
216
217 if removed > 0 {
218 l.stats.mu.Lock()
219 l.stats.ActiveLimiters -= int64(removed)
220 l.stats.mu.Unlock()
221 }
222}
223
224// Stop stops the cleanup goroutine.
225func (l *Limiter) Stop() {
226 close(l.stopCleanup)
227}
228
229// Stats returns current rate limiter statistics.
230func (l *Limiter) Stats() Stats {
231 l.stats.mu.RLock()
232 defer l.stats.mu.RUnlock()
233
234 // Update active limiters count
235 l.mu.RLock()
236 activeLimiters := int64(len(l.limiters))
237 l.mu.RUnlock()
238
239 return Stats{
240 ActiveLimiters: activeLimiters,
241 Allowed: l.stats.Allowed,
242 Denied: l.stats.Denied,
243 }
244}
245
246// incrementAllowed increments the allowed counter.
247func (l *Limiter) incrementAllowed() {
248 l.stats.mu.Lock()
249 l.stats.Allowed++
250 l.stats.mu.Unlock()
251}
252
253// incrementDenied increments the denied counter.
254func (l *Limiter) incrementDenied() {
255 l.stats.mu.Lock()
256 l.stats.Denied++
257 l.stats.mu.Unlock()
258}
259
260// incrementActiveLimiters increments the active limiters counter.
261func (l *Limiter) incrementActiveLimiters() {
262 l.stats.mu.Lock()
263 l.stats.ActiveLimiters++
264 l.stats.mu.Unlock()
265}
266
267// Reset clears all rate limiters and resets statistics.
268// Useful for testing.
269func (l *Limiter) Reset() {
270 l.mu.Lock()
271 l.limiters = make(map[string]*userLimiters)
272 l.mu.Unlock()
273
274 l.stats.mu.Lock()
275 l.stats.ActiveLimiters = 0
276 l.stats.Allowed = 0
277 l.stats.Denied = 0
278 l.stats.mu.Unlock()
279}