summaryrefslogtreecommitdiffstats
path: root/internal/metrics/interceptor.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/metrics/interceptor.go')
-rw-r--r--internal/metrics/interceptor.go74
1 files changed, 74 insertions, 0 deletions
diff --git a/internal/metrics/interceptor.go b/internal/metrics/interceptor.go
new file mode 100644
index 0000000..02eb69d
--- /dev/null
+++ b/internal/metrics/interceptor.go
@@ -0,0 +1,74 @@
1package metrics
2
3import (
4 "context"
5 "time"
6
7 "google.golang.org/grpc"
8 "google.golang.org/grpc/codes"
9 "google.golang.org/grpc/status"
10)
11
12// UnaryServerInterceptor creates a gRPC unary interceptor for metrics collection.
13// It should be the first interceptor in the chain to measure total request time.
14func UnaryServerInterceptor(m *Metrics) grpc.UnaryServerInterceptor {
15 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
16 start := time.Now()
17
18 // Call the handler
19 resp, err := handler(ctx, req)
20
21 // Record metrics
22 duration := time.Since(start).Seconds()
23 requestStatus := getRequestStatus(err)
24 m.RecordRequest(info.FullMethod, string(requestStatus), duration)
25
26 return resp, err
27 }
28}
29
30// StreamServerInterceptor creates a gRPC stream interceptor for metrics collection.
31func StreamServerInterceptor(m *Metrics) grpc.StreamServerInterceptor {
32 return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
33 start := time.Now()
34
35 // Increment subscriptions count
36 m.IncrementSubscriptions()
37 defer m.DecrementSubscriptions()
38
39 // Call the handler
40 err := handler(srv, ss)
41
42 // Record metrics
43 duration := time.Since(start).Seconds()
44 requestStatus := getRequestStatus(err)
45 m.RecordRequest(info.FullMethod, string(requestStatus), duration)
46
47 return err
48 }
49}
50
51// getRequestStatus determines the request status from an error.
52func getRequestStatus(err error) RequestStatus {
53 if err == nil {
54 return StatusOK
55 }
56
57 st, ok := status.FromError(err)
58 if !ok {
59 return StatusError
60 }
61
62 switch st.Code() {
63 case codes.OK:
64 return StatusOK
65 case codes.Unauthenticated:
66 return StatusUnauthenticated
67 case codes.ResourceExhausted:
68 return StatusRateLimited
69 case codes.InvalidArgument:
70 return StatusInvalidRequest
71 default:
72 return StatusError
73 }
74}