package ratelimit import ( "sync" "time" "golang.org/x/time/rate" ) // Limiter manages per-user rate limiting using the token bucket algorithm. type Limiter struct { config *Config // limiters maps identifier (pubkey or IP) to method-specific limiters limiters map[string]*userLimiters mu sync.RWMutex // stats tracks metrics stats Stats // cleanup manages periodic cleanup of idle limiters stopCleanup chan struct{} } // userLimiters holds rate limiters for a single user (pubkey or IP) type userLimiters struct { // default limiter for methods without specific limits defaultLimiter *rate.Limiter // method-specific limiters methodLimiters map[string]*rate.Limiter // last access time for cleanup lastAccess time.Time mu sync.RWMutex } // Stats holds rate limiter statistics. type Stats struct { ActiveLimiters int64 // Number of active users being tracked Allowed int64 // Total requests allowed Denied int64 // Total requests denied mu sync.RWMutex } // DenialRate returns the percentage of requests denied. func (s *Stats) DenialRate() float64 { s.mu.RLock() defer s.mu.RUnlock() total := s.Allowed + s.Denied if total == 0 { return 0 } return float64(s.Denied) / float64(total) * 100 } // New creates a new rate limiter with the given configuration. func New(config *Config) *Limiter { if config == nil { config = DefaultConfig() } config.Validate() l := &Limiter{ config: config, limiters: make(map[string]*userLimiters), stopCleanup: make(chan struct{}), } // Start cleanup goroutine go l.cleanupLoop() return l } // Allow checks if a request should be allowed for the given identifier and method. // identifier is either a pubkey (for authenticated users) or IP address. // method is the full gRPC method name. func (l *Limiter) Allow(identifier, method string) bool { // Check if method should be skipped if l.config.ShouldSkipMethod(method) { l.incrementAllowed() return true } // Check if user should be skipped if l.config.ShouldSkipUser(identifier) { l.incrementAllowed() return true } // Get or create user limiters userLims := l.getUserLimiters(identifier) // Get method-specific limiter limiter := userLims.getLimiterForMethod(method, l.config, identifier) // Check if request is allowed if limiter.Allow() { l.incrementAllowed() return true } l.incrementDenied() return false } // getUserLimiters gets or creates the limiters for a user. func (l *Limiter) getUserLimiters(identifier string) *userLimiters { // Try read lock first (fast path) l.mu.RLock() userLims, ok := l.limiters[identifier] l.mu.RUnlock() if ok { userLims.updateLastAccess() return userLims } // Need to create new limiters (slow path) l.mu.Lock() defer l.mu.Unlock() // Double-check after acquiring write lock userLims, ok = l.limiters[identifier] if ok { userLims.updateLastAccess() return userLims } // Create new user limiters userLims = &userLimiters{ methodLimiters: make(map[string]*rate.Limiter), lastAccess: time.Now(), } l.limiters[identifier] = userLims l.incrementActiveLimiters() return userLims } // getLimiterForMethod gets the rate limiter for a specific method. func (u *userLimiters) getLimiterForMethod(method string, config *Config, identifier string) *rate.Limiter { u.mu.RLock() limiter, ok := u.methodLimiters[method] u.mu.RUnlock() if ok { return limiter } // Create new limiter for this method u.mu.Lock() defer u.mu.Unlock() // Double-check after acquiring write lock limiter, ok = u.methodLimiters[method] if ok { return limiter } // Get rate limit for this method and user rps, burst := config.GetLimitForMethod(identifier, method) // Create new rate limiter limiter = rate.NewLimiter(rate.Limit(rps), burst) u.methodLimiters[method] = limiter return limiter } // updateLastAccess updates the last access time for this user. func (u *userLimiters) updateLastAccess() { u.mu.Lock() u.lastAccess = time.Now() u.mu.Unlock() } // isIdle returns true if this user hasn't been accessed recently. func (u *userLimiters) isIdle(maxIdleTime time.Duration) bool { u.mu.RLock() defer u.mu.RUnlock() return time.Since(u.lastAccess) > maxIdleTime } // cleanupLoop periodically removes idle limiters to free memory. func (l *Limiter) cleanupLoop() { ticker := time.NewTicker(l.config.CleanupInterval) defer ticker.Stop() for { select { case <-ticker.C: l.cleanup() case <-l.stopCleanup: return } } } // cleanup removes idle limiters from memory. func (l *Limiter) cleanup() { l.mu.Lock() defer l.mu.Unlock() removed := 0 for identifier, userLims := range l.limiters { if userLims.isIdle(l.config.MaxIdleTime) { delete(l.limiters, identifier) removed++ } } if removed > 0 { l.stats.mu.Lock() l.stats.ActiveLimiters -= int64(removed) l.stats.mu.Unlock() } } // Stop stops the cleanup goroutine. func (l *Limiter) Stop() { close(l.stopCleanup) } // Stats returns current rate limiter statistics. func (l *Limiter) Stats() Stats { l.stats.mu.RLock() defer l.stats.mu.RUnlock() // Update active limiters count l.mu.RLock() activeLimiters := int64(len(l.limiters)) l.mu.RUnlock() return Stats{ ActiveLimiters: activeLimiters, Allowed: l.stats.Allowed, Denied: l.stats.Denied, } } // incrementAllowed increments the allowed counter. func (l *Limiter) incrementAllowed() { l.stats.mu.Lock() l.stats.Allowed++ l.stats.mu.Unlock() } // incrementDenied increments the denied counter. func (l *Limiter) incrementDenied() { l.stats.mu.Lock() l.stats.Denied++ l.stats.mu.Unlock() } // incrementActiveLimiters increments the active limiters counter. func (l *Limiter) incrementActiveLimiters() { l.stats.mu.Lock() l.stats.ActiveLimiters++ l.stats.mu.Unlock() } // Reset clears all rate limiters and resets statistics. // Useful for testing. func (l *Limiter) Reset() { l.mu.Lock() l.limiters = make(map[string]*userLimiters) l.mu.Unlock() l.stats.mu.Lock() l.stats.ActiveLimiters = 0 l.stats.Allowed = 0 l.stats.Denied = 0 l.stats.mu.Unlock() }