package ratelimit

import (
	"context"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// TestBasicRateLimit checks that a single identifier may spend its full burst,
// is denied once the burst is exhausted, and is allowed again after waiting
// long enough for a token to be refilled.
func TestBasicRateLimit(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 10,
		BurstSize:         10,
	}

	limiter := New(config)
	defer limiter.Stop()

	identifier := "test-user"
	method := "/test.Service/Method"

	// First 10 requests should succeed (burst)
	for i := 0; i < 10; i++ {
		if !limiter.Allow(identifier, method) {
			t.Errorf("request %d should be allowed", i)
		}
	}

	// 11th request should be denied (burst exhausted)
	if limiter.Allow(identifier, method) {
		t.Error("request 11 should be denied")
	}

	// Wait for tokens to refill: at 10 requests/second, 150ms is expected to
	// be enough for one token.
	time.Sleep(150 * time.Millisecond)

	// Should allow 1 more request (1 token refilled)
	if !limiter.Allow(identifier, method) {
		t.Error("request after refill should be allowed")
	}
}

// TestPerUserLimits checks that exhausting one identifier's quota does not
// affect a different identifier calling the same method.
func TestPerUserLimits(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 10,
		BurstSize:         10,
	}

	limiter := New(config)
	defer limiter.Stop()

	method := "/test.Service/Method"

	// Different users should have independent limits
	user1 := "user1"
	user2 := "user2"

	// Exhaust user1's quota
	for i := 0; i < 10; i++ {
		limiter.Allow(user1, method)
	}

	// User1 should be denied
	if limiter.Allow(user1, method) {
		t.Error("user1 should be rate limited")
	}

	// User2 should still be allowed
	if !limiter.Allow(user2, method) {
		t.Error("user2 should not be rate limited")
	}
}

// TestMethodSpecificLimits checks that a method listed in Config.MethodLimits
// is limited by its own burst size while other methods use the defaults.
func TestMethodSpecificLimits(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 10,
		BurstSize:         10,
		MethodLimits: map[string]MethodLimit{
			"/test.Service/StrictMethod": {
				RequestsPerSecond: 2,
				BurstSize:         2,
			},
		},
	}

	limiter := New(config)
	defer limiter.Stop()

	identifier := "test-user"

	// Regular method should allow 10 requests
	regularMethod := "/test.Service/RegularMethod"
	for i := 0; i < 10; i++ {
		if !limiter.Allow(identifier, regularMethod) {
			t.Errorf("regular method request %d should be allowed", i)
		}
	}

	// Strict method should only allow 2 requests
	strictMethod := "/test.Service/StrictMethod"
	for i := 0; i < 2; i++ {
		if !limiter.Allow(identifier, strictMethod) {
			t.Errorf("strict method request %d should be allowed", i)
		}
	}

	// 3rd request should be denied
	if limiter.Allow(identifier, strictMethod) {
		t.Error("strict method request 3 should be denied")
	}
}

// TestUserSpecificLimits checks that an identifier listed in Config.UserLimits
// gets its own, larger burst while an unlisted identifier uses the defaults.
func TestUserSpecificLimits(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 10,
		BurstSize:         10,
		UserLimits: map[string]UserLimit{
			"vip-user": {
				RequestsPerSecond: 100,
				BurstSize:         100,
			},
		},
	}

	limiter := New(config)
	defer limiter.Stop()

	method := "/test.Service/Method"

	// Regular user should be limited to 10
	regularUser := "regular-user"
	for i := 0; i < 10; i++ {
		limiter.Allow(regularUser, method)
	}
	if limiter.Allow(regularUser, method) {
		t.Error("regular user should be rate limited")
	}

	// VIP user should allow 100
	vipUser := "vip-user"
	for i := 0; i < 100; i++ {
		if !limiter.Allow(vipUser, method) {
			t.Errorf("vip user request %d should be allowed", i)
		}
	}
}

// TestSkipMethods checks that a method listed in Config.SkipMethods is always
// allowed, even when the same identifier is already limited on another method.
func TestSkipMethods(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 1,
		BurstSize:         1,
		SkipMethods: []string{
			"/health/Check",
		},
	}

	limiter := New(config)
	defer limiter.Stop()

	identifier := "test-user"

	// Regular method should be rate limited
	regularMethod := "/test.Service/Method"
	limiter.Allow(identifier, regularMethod)
	if limiter.Allow(identifier, regularMethod) {
		t.Error("regular method should be rate limited")
	}

	// Skipped method should never be rate limited
	skipMethod := "/health/Check"
	for i := 0; i < 100; i++ {
		if !limiter.Allow(identifier, skipMethod) {
			// Stop at the first denial instead of reporting the same
			// failure for every remaining iteration.
			t.Fatal("skipped method should never be rate limited")
		}
	}
}

// TestSkipUsers checks that an identifier listed in Config.SkipUsers is always
// allowed, while an unlisted identifier is limited normally.
func TestSkipUsers(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 1,
		BurstSize:         1,
		SkipUsers: []string{
			"admin-user",
		},
	}

	limiter := New(config)
	defer limiter.Stop()

	method := "/test.Service/Method"

	// Regular user should be rate limited
	regularUser := "regular-user"
	limiter.Allow(regularUser, method)
	if limiter.Allow(regularUser, method) {
		t.Error("regular user should be rate limited")
	}

	// Admin user should never be rate limited
	adminUser := "admin-user"
	for i := 0; i < 100; i++ {
		if !limiter.Allow(adminUser, method) {
			// Stop at the first denial instead of reporting the same
			// failure for every remaining iteration.
			t.Fatal("admin user should never be rate limited")
		}
	}
}

// TestStats checks the counters returned by Stats after a known mix of
// allowed and denied requests: the allowed/denied totals, the number of
// active limiters, and the derived denial rate.
func TestStats(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 10,
		BurstSize:         5,
	}

	limiter := New(config)
	defer limiter.Stop()

	identifier := "test-user"
	method := "/test.Service/Method"

	// Make some requests
	for i := 0; i < 5; i++ {
		limiter.Allow(identifier, method) // All allowed (within burst)
	}
	for i := 0; i < 3; i++ {
		limiter.Allow(identifier, method) // All denied (burst exhausted)
	}

	stats := limiter.Stats()

	if stats.Allowed != 5 {
		t.Errorf("expected 5 allowed, got %d", stats.Allowed)
	}
	if stats.Denied != 3 {
		t.Errorf("expected 3 denied, got %d", stats.Denied)
	}
	if stats.ActiveLimiters != 1 {
		t.Errorf("expected 1 active limiter, got %d", stats.ActiveLimiters)
	}

	expectedDenialRate := 37.5 // 3/8 * 100
	if stats.DenialRate() != expectedDenialRate {
		t.Errorf("expected denial rate %.1f%%, got %.1f%%", expectedDenialRate, stats.DenialRate())
	}
}

// TestCleanup checks that limiters created for several identifiers are
// reported as active, and that none remain once they have been idle for
// longer than MaxIdleTime and the cleanup interval has elapsed.
func TestCleanup(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 10,
		BurstSize:         10,
		CleanupInterval:   100 * time.Millisecond,
		MaxIdleTime:       200 * time.Millisecond,
	}

	limiter := New(config)
	defer limiter.Stop()

	// Create limiters for multiple users. The rune arithmetic yields the
	// single digits "0".."4"; it is only valid while the loop bound is < 10.
	for i := 0; i < 5; i++ {
		limiter.Allow("user-"+string(rune('0'+i)), "/test")
	}

	stats := limiter.Stats()
	if stats.ActiveLimiters != 5 {
		t.Errorf("expected 5 active limiters, got %d", stats.ActiveLimiters)
	}

	// Wait for cleanup to run
	time.Sleep(350 * time.Millisecond)

	stats = limiter.Stats()
	if stats.ActiveLimiters != 0 {
		t.Errorf("expected 0 active limiters after cleanup, got %d", stats.ActiveLimiters)
	}
}

// TestUnaryInterceptor checks that the unary server interceptor passes
// requests through to the handler while within the burst, then rejects the
// next request with a gRPC ResourceExhausted status.
func TestUnaryInterceptor(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 2,
		BurstSize:         2,
	}

	limiter := New(config)
	defer limiter.Stop()

	interceptor := UnaryInterceptor(limiter)

	// Create a test handler
	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return "success", nil
	}

	info := &grpc.UnaryServerInfo{
		FullMethod: "/test.Service/Method",
	}

	// Create context with metadata (simulating IP)
	md := metadata.Pairs("x-real-ip", "192.168.1.1")
	ctx := metadata.NewIncomingContext(context.Background(), md)

	// First 2 requests should succeed
	for i := 0; i < 2; i++ {
		_, err := interceptor(ctx, nil, info, handler)
		if err != nil {
			t.Errorf("request %d should succeed, got error: %v", i, err)
		}
	}

	// 3rd request should be rate limited
	_, err := interceptor(ctx, nil, info, handler)
	if err == nil {
		// Nothing further can be checked without an error; stop here so a
		// follow-on status-code failure does not obscure the real problem.
		t.Fatal("expected rate limit error")
	}

	st, ok := status.FromError(err)
	if !ok {
		// The status code of a non-gRPC error is not meaningful to assert on.
		t.Fatal("expected gRPC status error")
	}
	if st.Code() != codes.ResourceExhausted {
		t.Errorf("expected ResourceExhausted, got %v", st.Code())
	}
}

// TestGetLimitForMethod is a table-driven check of the limit Config returns
// for a given identifier and method. The cases cover the global default, a
// method override, a user override, and a user-specific method override.
func TestGetLimitForMethod(t *testing.T) {
	config := &Config{
		RequestsPerSecond: 10,
		BurstSize:         20,
		MethodLimits: map[string]MethodLimit{
			"/test/Method1": {
				RequestsPerSecond: 5,
				BurstSize:         10,
			},
		},
		UserLimits: map[string]UserLimit{
			"vip-user": {
				RequestsPerSecond: 50,
				BurstSize:         100,
				MethodLimits: map[string]MethodLimit{
					"/test/Method1": {
						RequestsPerSecond: 25,
						BurstSize:         50,
					},
				},
			},
		},
	}

	tests := []struct {
		name          string
		pubkey        string
		method        string
		expectedRPS   float64
		expectedBurst int
	}{
		{
			name:          "default for regular user",
			pubkey:        "regular-user",
			method:        "/test/Method2",
			expectedRPS:   10,
			expectedBurst: 20,
		},
		{
			name:          "method limit for regular user",
			pubkey:        "regular-user",
			method:        "/test/Method1",
			expectedRPS:   5,
			expectedBurst: 10,
		},
		{
			name:          "user limit default method",
			pubkey:        "vip-user",
			method:        "/test/Method2",
			expectedRPS:   50,
			expectedBurst: 100,
		},
		{
			name:          "user method limit (highest precedence)",
			pubkey:        "vip-user",
			method:        "/test/Method1",
			expectedRPS:   25,
			expectedBurst: 50,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rps, burst := config.GetLimitForMethod(tt.pubkey, tt.method)
			if rps != tt.expectedRPS {
				t.Errorf("expected RPS %.1f, got %.1f", tt.expectedRPS, rps)
			}
			if burst != tt.expectedBurst {
				t.Errorf("expected burst %d, got %d", tt.expectedBurst, burst)
			}
		})
	}
}

// BenchmarkRateLimitAllow measures Allow for a single identifier configured
// with a large burst (1000) and refill rate (1000/s).
func BenchmarkRateLimitAllow(b *testing.B) {
	config := &Config{
		RequestsPerSecond: 1000,
		BurstSize:         1000,
	}

	limiter := New(config)
	defer limiter.Stop()

	identifier := "bench-user"
	method := "/test.Service/Method"

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		limiter.Allow(identifier, method)
	}
}

// BenchmarkRateLimitDeny measures Allow for a single identifier whose
// one-token burst has been spent before the timer starts.
func BenchmarkRateLimitDeny(b *testing.B) {
	config := &Config{
		RequestsPerSecond: 1,
		BurstSize:         1,
	}

	limiter := New(config)
	defer limiter.Stop()

	identifier := "bench-user"
	method := "/test.Service/Method"

	// Exhaust quota
	limiter.Allow(identifier, method)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		limiter.Allow(identifier, method)
	}
}