diff options
Diffstat (limited to 'internal/ratelimit/limiter.go')
| -rw-r--r-- | internal/ratelimit/limiter.go | 279 |
1 files changed, 279 insertions, 0 deletions
diff --git a/internal/ratelimit/limiter.go b/internal/ratelimit/limiter.go new file mode 100644 index 0000000..9d8c799 --- /dev/null +++ b/internal/ratelimit/limiter.go | |||
| @@ -0,0 +1,279 @@ | |||
| 1 | package ratelimit | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "sync" | ||
| 5 | "time" | ||
| 6 | |||
| 7 | "golang.org/x/time/rate" | ||
| 8 | ) | ||
| 9 | |||
| 10 | // Limiter manages per-user rate limiting using the token bucket algorithm. | ||
| 11 | type Limiter struct { | ||
| 12 | config *Config | ||
| 13 | |||
| 14 | // limiters maps identifier (pubkey or IP) to method-specific limiters | ||
| 15 | limiters map[string]*userLimiters | ||
| 16 | mu sync.RWMutex | ||
| 17 | |||
| 18 | // stats tracks metrics | ||
| 19 | stats Stats | ||
| 20 | |||
| 21 | // cleanup manages periodic cleanup of idle limiters | ||
| 22 | stopCleanup chan struct{} | ||
| 23 | } | ||
| 24 | |||
| 25 | // userLimiters holds rate limiters for a single user (pubkey or IP) | ||
| 26 | type userLimiters struct { | ||
| 27 | // default limiter for methods without specific limits | ||
| 28 | defaultLimiter *rate.Limiter | ||
| 29 | |||
| 30 | // method-specific limiters | ||
| 31 | methodLimiters map[string]*rate.Limiter | ||
| 32 | |||
| 33 | // last access time for cleanup | ||
| 34 | lastAccess time.Time | ||
| 35 | mu sync.RWMutex | ||
| 36 | } | ||
| 37 | |||
| 38 | // Stats holds rate limiter statistics. | ||
| 39 | type Stats struct { | ||
| 40 | ActiveLimiters int64 // Number of active users being tracked | ||
| 41 | Allowed int64 // Total requests allowed | ||
| 42 | Denied int64 // Total requests denied | ||
| 43 | mu sync.RWMutex | ||
| 44 | } | ||
| 45 | |||
| 46 | // DenialRate returns the percentage of requests denied. | ||
| 47 | func (s *Stats) DenialRate() float64 { | ||
| 48 | s.mu.RLock() | ||
| 49 | defer s.mu.RUnlock() | ||
| 50 | |||
| 51 | total := s.Allowed + s.Denied | ||
| 52 | if total == 0 { | ||
| 53 | return 0 | ||
| 54 | } | ||
| 55 | return float64(s.Denied) / float64(total) * 100 | ||
| 56 | } | ||
| 57 | |||
| 58 | // New creates a new rate limiter with the given configuration. | ||
| 59 | func New(config *Config) *Limiter { | ||
| 60 | if config == nil { | ||
| 61 | config = DefaultConfig() | ||
| 62 | } | ||
| 63 | config.Validate() | ||
| 64 | |||
| 65 | l := &Limiter{ | ||
| 66 | config: config, | ||
| 67 | limiters: make(map[string]*userLimiters), | ||
| 68 | stopCleanup: make(chan struct{}), | ||
| 69 | } | ||
| 70 | |||
| 71 | // Start cleanup goroutine | ||
| 72 | go l.cleanupLoop() | ||
| 73 | |||
| 74 | return l | ||
| 75 | } | ||
| 76 | |||
| 77 | // Allow checks if a request should be allowed for the given identifier and method. | ||
| 78 | // identifier is either a pubkey (for authenticated users) or IP address. | ||
| 79 | // method is the full gRPC method name. | ||
| 80 | func (l *Limiter) Allow(identifier, method string) bool { | ||
| 81 | // Check if method should be skipped | ||
| 82 | if l.config.ShouldSkipMethod(method) { | ||
| 83 | l.incrementAllowed() | ||
| 84 | return true | ||
| 85 | } | ||
| 86 | |||
| 87 | // Check if user should be skipped | ||
| 88 | if l.config.ShouldSkipUser(identifier) { | ||
| 89 | l.incrementAllowed() | ||
| 90 | return true | ||
| 91 | } | ||
| 92 | |||
| 93 | // Get or create user limiters | ||
| 94 | userLims := l.getUserLimiters(identifier) | ||
| 95 | |||
| 96 | // Get method-specific limiter | ||
| 97 | limiter := userLims.getLimiterForMethod(method, l.config, identifier) | ||
| 98 | |||
| 99 | // Check if request is allowed | ||
| 100 | if limiter.Allow() { | ||
| 101 | l.incrementAllowed() | ||
| 102 | return true | ||
| 103 | } | ||
| 104 | |||
| 105 | l.incrementDenied() | ||
| 106 | return false | ||
| 107 | } | ||
| 108 | |||
| 109 | // getUserLimiters gets or creates the limiters for a user. | ||
| 110 | func (l *Limiter) getUserLimiters(identifier string) *userLimiters { | ||
| 111 | // Try read lock first (fast path) | ||
| 112 | l.mu.RLock() | ||
| 113 | userLims, ok := l.limiters[identifier] | ||
| 114 | l.mu.RUnlock() | ||
| 115 | |||
| 116 | if ok { | ||
| 117 | userLims.updateLastAccess() | ||
| 118 | return userLims | ||
| 119 | } | ||
| 120 | |||
| 121 | // Need to create new limiters (slow path) | ||
| 122 | l.mu.Lock() | ||
| 123 | defer l.mu.Unlock() | ||
| 124 | |||
| 125 | // Double-check after acquiring write lock | ||
| 126 | userLims, ok = l.limiters[identifier] | ||
| 127 | if ok { | ||
| 128 | userLims.updateLastAccess() | ||
| 129 | return userLims | ||
| 130 | } | ||
| 131 | |||
| 132 | // Create new user limiters | ||
| 133 | userLims = &userLimiters{ | ||
| 134 | methodLimiters: make(map[string]*rate.Limiter), | ||
| 135 | lastAccess: time.Now(), | ||
| 136 | } | ||
| 137 | |||
| 138 | l.limiters[identifier] = userLims | ||
| 139 | l.incrementActiveLimiters() | ||
| 140 | |||
| 141 | return userLims | ||
| 142 | } | ||
| 143 | |||
| 144 | // getLimiterForMethod gets the rate limiter for a specific method. | ||
| 145 | func (u *userLimiters) getLimiterForMethod(method string, config *Config, identifier string) *rate.Limiter { | ||
| 146 | u.mu.RLock() | ||
| 147 | limiter, ok := u.methodLimiters[method] | ||
| 148 | u.mu.RUnlock() | ||
| 149 | |||
| 150 | if ok { | ||
| 151 | return limiter | ||
| 152 | } | ||
| 153 | |||
| 154 | // Create new limiter for this method | ||
| 155 | u.mu.Lock() | ||
| 156 | defer u.mu.Unlock() | ||
| 157 | |||
| 158 | // Double-check after acquiring write lock | ||
| 159 | limiter, ok = u.methodLimiters[method] | ||
| 160 | if ok { | ||
| 161 | return limiter | ||
| 162 | } | ||
| 163 | |||
| 164 | // Get rate limit for this method and user | ||
| 165 | rps, burst := config.GetLimitForMethod(identifier, method) | ||
| 166 | |||
| 167 | // Create new rate limiter | ||
| 168 | limiter = rate.NewLimiter(rate.Limit(rps), burst) | ||
| 169 | u.methodLimiters[method] = limiter | ||
| 170 | |||
| 171 | return limiter | ||
| 172 | } | ||
| 173 | |||
| 174 | // updateLastAccess updates the last access time for this user. | ||
| 175 | func (u *userLimiters) updateLastAccess() { | ||
| 176 | u.mu.Lock() | ||
| 177 | u.lastAccess = time.Now() | ||
| 178 | u.mu.Unlock() | ||
| 179 | } | ||
| 180 | |||
| 181 | // isIdle returns true if this user hasn't been accessed recently. | ||
| 182 | func (u *userLimiters) isIdle(maxIdleTime time.Duration) bool { | ||
| 183 | u.mu.RLock() | ||
| 184 | defer u.mu.RUnlock() | ||
| 185 | return time.Since(u.lastAccess) > maxIdleTime | ||
| 186 | } | ||
| 187 | |||
| 188 | // cleanupLoop periodically removes idle limiters to free memory. | ||
| 189 | func (l *Limiter) cleanupLoop() { | ||
| 190 | ticker := time.NewTicker(l.config.CleanupInterval) | ||
| 191 | defer ticker.Stop() | ||
| 192 | |||
| 193 | for { | ||
| 194 | select { | ||
| 195 | case <-ticker.C: | ||
| 196 | l.cleanup() | ||
| 197 | case <-l.stopCleanup: | ||
| 198 | return | ||
| 199 | } | ||
| 200 | } | ||
| 201 | } | ||
| 202 | |||
| 203 | // cleanup removes idle limiters from memory. | ||
| 204 | func (l *Limiter) cleanup() { | ||
| 205 | l.mu.Lock() | ||
| 206 | defer l.mu.Unlock() | ||
| 207 | |||
| 208 | removed := 0 | ||
| 209 | |||
| 210 | for identifier, userLims := range l.limiters { | ||
| 211 | if userLims.isIdle(l.config.MaxIdleTime) { | ||
| 212 | delete(l.limiters, identifier) | ||
| 213 | removed++ | ||
| 214 | } | ||
| 215 | } | ||
| 216 | |||
| 217 | if removed > 0 { | ||
| 218 | l.stats.mu.Lock() | ||
| 219 | l.stats.ActiveLimiters -= int64(removed) | ||
| 220 | l.stats.mu.Unlock() | ||
| 221 | } | ||
| 222 | } | ||
| 223 | |||
| 224 | // Stop stops the cleanup goroutine. | ||
| 225 | func (l *Limiter) Stop() { | ||
| 226 | close(l.stopCleanup) | ||
| 227 | } | ||
| 228 | |||
| 229 | // Stats returns current rate limiter statistics. | ||
| 230 | func (l *Limiter) Stats() Stats { | ||
| 231 | l.stats.mu.RLock() | ||
| 232 | defer l.stats.mu.RUnlock() | ||
| 233 | |||
| 234 | // Update active limiters count | ||
| 235 | l.mu.RLock() | ||
| 236 | activeLimiters := int64(len(l.limiters)) | ||
| 237 | l.mu.RUnlock() | ||
| 238 | |||
| 239 | return Stats{ | ||
| 240 | ActiveLimiters: activeLimiters, | ||
| 241 | Allowed: l.stats.Allowed, | ||
| 242 | Denied: l.stats.Denied, | ||
| 243 | } | ||
| 244 | } | ||
| 245 | |||
| 246 | // incrementAllowed increments the allowed counter. | ||
| 247 | func (l *Limiter) incrementAllowed() { | ||
| 248 | l.stats.mu.Lock() | ||
| 249 | l.stats.Allowed++ | ||
| 250 | l.stats.mu.Unlock() | ||
| 251 | } | ||
| 252 | |||
| 253 | // incrementDenied increments the denied counter. | ||
| 254 | func (l *Limiter) incrementDenied() { | ||
| 255 | l.stats.mu.Lock() | ||
| 256 | l.stats.Denied++ | ||
| 257 | l.stats.mu.Unlock() | ||
| 258 | } | ||
| 259 | |||
| 260 | // incrementActiveLimiters increments the active limiters counter. | ||
| 261 | func (l *Limiter) incrementActiveLimiters() { | ||
| 262 | l.stats.mu.Lock() | ||
| 263 | l.stats.ActiveLimiters++ | ||
| 264 | l.stats.mu.Unlock() | ||
| 265 | } | ||
| 266 | |||
| 267 | // Reset clears all rate limiters and resets statistics. | ||
| 268 | // Useful for testing. | ||
| 269 | func (l *Limiter) Reset() { | ||
| 270 | l.mu.Lock() | ||
| 271 | l.limiters = make(map[string]*userLimiters) | ||
| 272 | l.mu.Unlock() | ||
| 273 | |||
| 274 | l.stats.mu.Lock() | ||
| 275 | l.stats.ActiveLimiters = 0 | ||
| 276 | l.stats.Allowed = 0 | ||
| 277 | l.stats.Denied = 0 | ||
| 278 | l.stats.mu.Unlock() | ||
| 279 | } | ||
