package metrics import ( "context" "time" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // UnaryServerInterceptor creates a gRPC unary interceptor for metrics collection. // It should be the first interceptor in the chain to measure total request time. func UnaryServerInterceptor(m *Metrics) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { start := time.Now() // Call the handler resp, err := handler(ctx, req) // Record metrics duration := time.Since(start).Seconds() requestStatus := getRequestStatus(err) m.RecordRequest(info.FullMethod, string(requestStatus), duration) return resp, err } } // StreamServerInterceptor creates a gRPC stream interceptor for metrics collection. func StreamServerInterceptor(m *Metrics) grpc.StreamServerInterceptor { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { start := time.Now() // Increment subscriptions count m.IncrementSubscriptions() defer m.DecrementSubscriptions() // Call the handler err := handler(srv, ss) // Record metrics duration := time.Since(start).Seconds() requestStatus := getRequestStatus(err) m.RecordRequest(info.FullMethod, string(requestStatus), duration) return err } } // getRequestStatus determines the request status from an error. func getRequestStatus(err error) RequestStatus { if err == nil { return StatusOK } st, ok := status.FromError(err) if !ok { return StatusError } switch st.Code() { case codes.OK: return StatusOK case codes.Unauthenticated: return StatusUnauthenticated case codes.ResourceExhausted: return StatusRateLimited case codes.InvalidArgument: return StatusInvalidRequest default: return StatusError } }