package ratelimit

import (
	"context"
	"fmt"
	"net"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"

	"northwest.io/muxstr/internal/auth"
)

// UnaryInterceptor creates a gRPC unary interceptor for rate limiting.
// It should be chained after the auth interceptor so pubkey is available.
// Requests over the limit are rejected with codes.ResourceExhausted and the
// handler is not invoked.
func UnaryInterceptor(limiter *Limiter) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// Bucket key: authenticated pubkey, otherwise the client host.
		identifier := getIdentifier(ctx)

		if !limiter.Allow(identifier, info.FullMethod) {
			return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded for %s", identifier)
		}

		return handler(ctx, req)
	}
}

// StreamInterceptor creates a gRPC stream interceptor for rate limiting.
// It should be chained after the auth interceptor so pubkey is available.
// The limit is checked once, when the stream is opened; individual messages
// on an already-admitted stream are not limited by this interceptor.
func StreamInterceptor(limiter *Limiter) grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		// Bucket key: authenticated pubkey, otherwise the client host.
		identifier := getIdentifier(ss.Context())

		if !limiter.Allow(identifier, info.FullMethod) {
			return status.Errorf(codes.ResourceExhausted, "rate limit exceeded for %s", identifier)
		}

		return handler(srv, ss)
	}
}

// getIdentifier extracts the user identifier from the context.
// Returns the pubkey if authenticated (an empty pubkey counts as
// unauthenticated); otherwise returns the client's host address with any
// port removed.
func getIdentifier(ctx context.Context) string {
	if pubkey, ok := auth.PubkeyFromContext(ctx); ok && pubkey != "" {
		return pubkey
	}

	// The peer address is usually "host:port". Keying on the port would give
	// every new connection its own bucket, letting a client evade the limit
	// simply by reconnecting, so key on the host alone.
	return hostOnly(getIPAddress(ctx))
}

// hostOnly strips a trailing ":port" (and IPv6 brackets) from addr. Values
// that are not in host:port form, such as a bare IP or "unknown", are
// returned unchanged.
func hostOnly(addr string) string {
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}

// getIPAddress extracts the client IP address from the context.
// For transport peers the value is typically "host:port"; callers that want
// one rate-limit bucket per host must strip the port.
func getIPAddress(ctx context.Context) string { // Try to get from peer info p, ok := peer.FromContext(ctx) if ok && p.Addr != nil { return p.Addr.String() } // Try to get from metadata (X-Forwarded-For header) md, ok := metadata.FromIncomingContext(ctx) if ok { if xff := md.Get("x-forwarded-for"); len(xff) > 0 { return xff[0] } if xri := md.Get("x-real-ip"); len(xri) > 0 { return xri[0] } } return "unknown" } // WaitUnaryInterceptor is a variant that waits instead of rejecting when rate limited. // Use this for critical operations that should never fail due to rate limiting. // WARNING: This can cause requests to hang if rate limit is never freed. func WaitUnaryInterceptor(limiter *Limiter) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { identifier := getIdentifier(ctx) // Get user limiters userLims := limiter.getUserLimiters(identifier) rateLimiter := userLims.getLimiterForMethod(info.FullMethod, limiter.config, identifier) // Wait for permission (respects context deadline) if err := rateLimiter.Wait(ctx); err != nil { return nil, status.Errorf(codes.DeadlineExceeded, "rate limit wait cancelled: %v", err) } limiter.incrementAllowed() return handler(ctx, req) } } // RetryableError wraps a rate limit error with retry-after information. type RetryableError struct { *status.Status RetryAfter float64 // seconds } // Error implements the error interface. func (e *RetryableError) Error() string { return fmt.Sprintf("%s (retry after %.1fs)", e.Status.Message(), e.RetryAfter) } // UnaryInterceptorWithRetryAfter is like UnaryInterceptor but includes retry-after info. // Clients can extract this to implement smart backoff. 
func UnaryInterceptorWithRetryAfter(limiter *Limiter) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { identifier := getIdentifier(ctx) // Get user limiters userLims := limiter.getUserLimiters(identifier) rateLimiter := userLims.getLimiterForMethod(info.FullMethod, limiter.config, identifier) // Get reservation to check how long to wait reservation := rateLimiter.Reserve() if !reservation.OK() { return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded (burst exhausted)") } delay := reservation.Delay() if delay > 0 { // Cancel the reservation since we're not going to wait reservation.Cancel() limiter.incrementDenied() // Return error with retry-after information st := status.New(codes.ResourceExhausted, fmt.Sprintf("rate limit exceeded for %s", identifier)) return nil, &RetryableError{ Status: st, RetryAfter: delay.Seconds(), } } // No delay needed, proceed limiter.incrementAllowed() return handler(ctx, req) } }