summaryrefslogtreecommitdiffstats
path: root/internal/metrics/interceptor.go
blob: 02eb69d1340b0df3fee83443cd41de47ed680f8e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package metrics

import (
	"context"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// UnaryServerInterceptor returns a gRPC unary server interceptor that reports
// every call to m. For each request it invokes the handler, then records the
// full method name, the status derived from the handler's error, and the
// elapsed wall-clock time in seconds. The handler's response and error are
// returned to the caller unchanged.
//
// Install it as the first interceptor in the chain so the measured duration
// covers the total request time.
func UnaryServerInterceptor(m *Metrics) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		begin := time.Now()
		resp, err := handler(ctx, req)
		elapsed := time.Since(begin)

		// The status label is derived from the handler error only; the
		// response value is never inspected.
		outcome := string(getRequestStatus(err))
		m.RecordRequest(info.FullMethod, outcome, elapsed.Seconds())

		return resp, err
	}
}

// StreamServerInterceptor returns a gRPC stream server interceptor that
// reports every stream to m. While the handler runs, the stream is counted via
// m.IncrementSubscriptions; the matching m.DecrementSubscriptions is deferred
// and runs when the interceptor returns. After the handler finishes, the full
// method name, the status derived from the handler's error, and the stream's
// lifetime in seconds are recorded. The handler's error is returned unchanged.
func StreamServerInterceptor(m *Metrics) grpc.StreamServerInterceptor {
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		begin := time.Now()

		// Track the stream for as long as the handler is running.
		m.IncrementSubscriptions()
		defer m.DecrementSubscriptions()

		err := handler(srv, ss)
		elapsed := time.Since(begin)

		outcome := string(getRequestStatus(err))
		m.RecordRequest(info.FullMethod, outcome, elapsed.Seconds())

		return err
	}
}

// getRequestStatus maps a handler error to the RequestStatus label used when
// recording metrics.
//
// A nil error and codes.OK map to StatusOK. codes.Unauthenticated,
// codes.ResourceExhausted and codes.InvalidArgument map to
// StatusUnauthenticated, StatusRateLimited and StatusInvalidRequest
// respectively. Every other gRPC code, and any error that status.FromError
// does not recognise as a gRPC status, maps to StatusError.
func getRequestStatus(err error) RequestStatus {
	if err == nil {
		return StatusOK
	}

	// Errors without a gRPC status are treated as codes.Unknown here, which
	// has no dedicated case below and therefore ends up as StatusError.
	code := codes.Unknown
	if grpcStatus, isStatus := status.FromError(err); isStatus {
		code = grpcStatus.Code()
	}

	switch code {
	case codes.Unauthenticated:
		return StatusUnauthenticated
	case codes.ResourceExhausted:
		return StatusRateLimited
	case codes.InvalidArgument:
		return StatusInvalidRequest
	case codes.OK:
		return StatusOK
	}
	return StatusError
}