diff options
| author | bndw <ben@bdw.to> | 2026-03-09 08:01:02 -0700 |
|---|---|---|
| committer | bndw <ben@bdw.to> | 2026-03-09 08:01:02 -0700 |
| commit | 61a85baf87d89fcc09f9469a113a2ddc982b0a24 (patch) | |
| tree | d8359ce5cbcbb9402ba92c617c4ebd702adf33e9 /relay/handler.go | |
| parent | ce684848e25fed3aabdde4ffba6d2d8c40afa030 (diff) | |
feat: phase 2 relay implementation
Implement the Axon relay as relay/ (module axon/relay). Includes:
- WebSocket framing (RFC 6455, no external deps) in relay/websocket/
- Per-connection auth: challenge/response with ed25519 + allowlist check
- Ingest pipeline: sig verify, dedup, ephemeral fanout, SQLite persistence
- Subscription manager with prefix-matching filter fanout in relay/subscription/
- SQLite storage with WAL/cache config and UNION query builder in relay/storage/
- Graceful shutdown on SIGINT/SIGTERM
- Filter/TagFilter types added to axon core package (required by relay)
Diffstat (limited to 'relay/handler.go')
| -rw-r--r-- | relay/handler.go | 411 |
1 files changed, 411 insertions, 0 deletions
diff --git a/relay/handler.go b/relay/handler.go new file mode 100644 index 0000000..afd9622 --- /dev/null +++ b/relay/handler.go | |||
| @@ -0,0 +1,411 @@ | |||
| 1 | package main | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "bytes" | ||
| 5 | "context" | ||
| 6 | "fmt" | ||
| 7 | "log" | ||
| 8 | "time" | ||
| 9 | |||
| 10 | "github.com/vmihailenco/msgpack/v5" | ||
| 11 | |||
| 12 | "axon" | ||
| 13 | "axon/relay/storage" | ||
| 14 | "axon/relay/subscription" | ||
| 15 | ws "axon/relay/websocket" | ||
| 16 | ) | ||
| 17 | |||
| 18 | // Wire protocol message type constants. | ||
| 19 | const ( | ||
| 20 | // Client → Relay | ||
| 21 | MsgTypeAuth uint16 = 1 | ||
| 22 | MsgTypeSubscribe uint16 = 2 | ||
| 23 | MsgTypeUnsubscribe uint16 = 3 | ||
| 24 | MsgTypePublish uint16 = 4 | ||
| 25 | |||
| 26 | // Relay → Client | ||
| 27 | MsgTypeChallenge uint16 = 10 | ||
| 28 | MsgTypeEvent uint16 = 11 | ||
| 29 | MsgTypeEose uint16 = 12 | ||
| 30 | MsgTypeOk uint16 = 13 | ||
| 31 | MsgTypeError uint16 = 14 | ||
| 32 | ) | ||
| 33 | |||
| 34 | // Payload types — Relay → Client | ||
| 35 | |||
| 36 | type ChallengePayload struct { | ||
| 37 | Nonce []byte `msgpack:"nonce"` | ||
| 38 | } | ||
| 39 | |||
| 40 | type EventPayload struct { | ||
| 41 | SubID string `msgpack:"sub_id"` | ||
| 42 | Event axon.Event `msgpack:"event"` | ||
| 43 | } | ||
| 44 | |||
| 45 | type EosePayload struct { | ||
| 46 | SubID string `msgpack:"sub_id"` | ||
| 47 | } | ||
| 48 | |||
| 49 | type OkPayload struct { | ||
| 50 | Message string `msgpack:"message"` | ||
| 51 | } | ||
| 52 | |||
| 53 | type ErrorPayload struct { | ||
| 54 | Code uint16 `msgpack:"code"` | ||
| 55 | Message string `msgpack:"message"` | ||
| 56 | } | ||
| 57 | |||
| 58 | // Payload types — Client → Relay | ||
| 59 | |||
| 60 | type AuthPayload struct { | ||
| 61 | PubKey []byte `msgpack:"pubkey"` | ||
| 62 | Sig []byte `msgpack:"sig"` | ||
| 63 | } | ||
| 64 | |||
| 65 | type SubscribePayload struct { | ||
| 66 | SubID string `msgpack:"sub_id"` | ||
| 67 | Filter axon.Filter `msgpack:"filter"` | ||
| 68 | } | ||
| 69 | |||
| 70 | type UnsubscribePayload struct { | ||
| 71 | SubID string `msgpack:"sub_id"` | ||
| 72 | } | ||
| 73 | |||
| 74 | type PublishPayload struct { | ||
| 75 | Event axon.Event `msgpack:"event"` | ||
| 76 | } | ||
| 77 | |||
| 78 | // conn holds per-connection state. | ||
| 79 | type conn struct { | ||
| 80 | id string // unique connection ID (hex nonce for logging) | ||
| 81 | ws *ws.Conn | ||
| 82 | store *storage.Storage | ||
| 83 | global *subscription.GlobalManager | ||
| 84 | allowlist [][]byte | ||
| 85 | relayURL string | ||
| 86 | |||
| 87 | authed bool | ||
| 88 | pubkey []byte | ||
| 89 | nonce []byte | ||
| 90 | |||
| 91 | mgr *subscription.Manager | ||
| 92 | } | ||
| 93 | |||
| 94 | // send encodes and writes a wire message to the client. | ||
| 95 | func (c *conn) send(msgType uint16, payload interface{}) error { | ||
| 96 | b, err := msgpack.Marshal([]interface{}{msgType, payload}) | ||
| 97 | if err != nil { | ||
| 98 | return fmt.Errorf("handler: marshal msg %d: %w", msgType, err) | ||
| 99 | } | ||
| 100 | return c.ws.Write(b) | ||
| 101 | } | ||
| 102 | |||
| 103 | // sendError sends an ErrorPayload. If fatal is true the connection is then closed. | ||
| 104 | func (c *conn) sendError(code uint16, message string, fatal bool) { | ||
| 105 | _ = c.send(MsgTypeError, ErrorPayload{Code: code, Message: message}) | ||
| 106 | if fatal { | ||
| 107 | c.ws.Close(1000, "") | ||
| 108 | } | ||
| 109 | } | ||
| 110 | |||
| 111 | // serve is the main per-connection loop. | ||
| 112 | func (c *conn) serve(ctx context.Context) { | ||
| 113 | defer func() { | ||
| 114 | c.mgr.CloseAll() | ||
| 115 | c.global.UnregisterConn(c.id) | ||
| 116 | }() | ||
| 117 | |||
| 118 | // Send challenge immediately on connect. | ||
| 119 | if err := c.send(MsgTypeChallenge, ChallengePayload{Nonce: c.nonce}); err != nil { | ||
| 120 | log.Printf("conn %s: send challenge: %v", c.id, err) | ||
| 121 | return | ||
| 122 | } | ||
| 123 | |||
| 124 | // Start ping goroutine. | ||
| 125 | pingStop := make(chan struct{}) | ||
| 126 | defer close(pingStop) | ||
| 127 | go c.pingLoop(pingStop) | ||
| 128 | |||
| 129 | for { | ||
| 130 | raw, err := c.ws.Read(ctx) | ||
| 131 | if err != nil { | ||
| 132 | if ctx.Err() != nil { | ||
| 133 | return | ||
| 134 | } | ||
| 135 | // Connection closed or read error. | ||
| 136 | return | ||
| 137 | } | ||
| 138 | |||
| 139 | if err := c.dispatch(ctx, raw); err != nil { | ||
| 140 | log.Printf("conn %s: dispatch: %v", c.id, err) | ||
| 141 | return | ||
| 142 | } | ||
| 143 | } | ||
| 144 | } | ||
| 145 | |||
| 146 | // pingLoop sends a WebSocket ping every 30 seconds. If two consecutive pings | ||
| 147 | // go unanswered (no pong received within 30s each) the connection is closed. | ||
| 148 | func (c *conn) pingLoop(stop <-chan struct{}) { | ||
| 149 | ticker := time.NewTicker(30 * time.Second) | ||
| 150 | defer ticker.Stop() | ||
| 151 | missed := 0 | ||
| 152 | for { | ||
| 153 | select { | ||
| 154 | case <-stop: | ||
| 155 | return | ||
| 156 | case <-ticker.C: | ||
| 157 | if err := c.ws.Ping(); err != nil { | ||
| 158 | return | ||
| 159 | } | ||
| 160 | missed++ | ||
| 161 | if missed >= 2 { | ||
| 162 | log.Printf("conn %s: missed 2 pings, closing", c.id) | ||
| 163 | c.ws.Close(1001, "ping timeout") | ||
| 164 | return | ||
| 165 | } | ||
| 166 | } | ||
| 167 | } | ||
| 168 | } | ||
| 169 | |||
| 170 | // dispatch parses a raw message and routes it to the appropriate handler. | ||
| 171 | func (c *conn) dispatch(ctx context.Context, raw []byte) error { | ||
| 172 | // Decode as [uint16, rawPayload] | ||
| 173 | var arr []msgpack.RawMessage | ||
| 174 | if err := msgpack.Unmarshal(raw, &arr); err != nil { | ||
| 175 | c.sendError(400, "malformed message", false) | ||
| 176 | return nil | ||
| 177 | } | ||
| 178 | if len(arr) < 2 { | ||
| 179 | c.sendError(400, "message too short", false) | ||
| 180 | return nil | ||
| 181 | } | ||
| 182 | |||
| 183 | var msgType uint16 | ||
| 184 | if err := msgpack.Unmarshal(arr[0], &msgType); err != nil { | ||
| 185 | c.sendError(400, "invalid message type", false) | ||
| 186 | return nil | ||
| 187 | } | ||
| 188 | |||
| 189 | // Only MsgTypeAuth is allowed before authentication. | ||
| 190 | if !c.authed && msgType != MsgTypeAuth { | ||
| 191 | c.sendError(401, "not authenticated", true) | ||
| 192 | return fmt.Errorf("unauthenticated message type %d", msgType) | ||
| 193 | } | ||
| 194 | |||
| 195 | switch msgType { | ||
| 196 | case MsgTypeAuth: | ||
| 197 | return c.handleAuth(arr[1]) | ||
| 198 | case MsgTypePublish: | ||
| 199 | return c.handlePublish(ctx, arr[1]) | ||
| 200 | case MsgTypeSubscribe: | ||
| 201 | return c.handleSubscribe(ctx, arr[1]) | ||
| 202 | case MsgTypeUnsubscribe: | ||
| 203 | return c.handleUnsubscribe(arr[1]) | ||
| 204 | default: | ||
| 205 | c.sendError(400, fmt.Sprintf("unknown message type %d", msgType), false) | ||
| 206 | } | ||
| 207 | return nil | ||
| 208 | } | ||
| 209 | |||
| 210 | // handleAuth processes an Auth message. | ||
| 211 | func (c *conn) handleAuth(raw msgpack.RawMessage) error { | ||
| 212 | var p AuthPayload | ||
| 213 | if err := msgpack.Unmarshal(raw, &p); err != nil { | ||
| 214 | c.sendError(400, "malformed auth payload", true) | ||
| 215 | return fmt.Errorf("unmarshal auth: %w", err) | ||
| 216 | } | ||
| 217 | |||
| 218 | if len(p.PubKey) != 32 { | ||
| 219 | c.sendError(400, "pubkey must be 32 bytes", true) | ||
| 220 | return fmt.Errorf("bad pubkey length %d", len(p.PubKey)) | ||
| 221 | } | ||
| 222 | |||
| 223 | if !axon.VerifyChallenge(p.PubKey, c.nonce, c.relayURL, p.Sig) { | ||
| 224 | c.sendError(401, "invalid signature", true) | ||
| 225 | return fmt.Errorf("auth signature invalid") | ||
| 226 | } | ||
| 227 | |||
| 228 | // Check allowlist. | ||
| 229 | if len(c.allowlist) > 0 { | ||
| 230 | allowed := false | ||
| 231 | for _, pk := range c.allowlist { | ||
| 232 | if bytes.Equal(pk, p.PubKey) { | ||
| 233 | allowed = true | ||
| 234 | break | ||
| 235 | } | ||
| 236 | } | ||
| 237 | if !allowed { | ||
| 238 | c.sendError(403, "pubkey not in allowlist", true) | ||
| 239 | return fmt.Errorf("pubkey not in allowlist") | ||
| 240 | } | ||
| 241 | } | ||
| 242 | |||
| 243 | c.authed = true | ||
| 244 | c.pubkey = p.PubKey | ||
| 245 | return c.send(MsgTypeOk, OkPayload{Message: "authenticated"}) | ||
| 246 | } | ||
| 247 | |||
| 248 | // handlePublish processes a Publish message. | ||
| 249 | func (c *conn) handlePublish(ctx context.Context, raw msgpack.RawMessage) error { | ||
| 250 | var p PublishPayload | ||
| 251 | if err := msgpack.Unmarshal(raw, &p); err != nil { | ||
| 252 | c.sendError(400, "malformed publish payload", false) | ||
| 253 | return nil | ||
| 254 | } | ||
| 255 | |||
| 256 | event := &p.Event | ||
| 257 | |||
| 258 | // Content length check. | ||
| 259 | if len(event.Content) > 65536 { | ||
| 260 | c.sendError(413, "content exceeds 64KB limit", false) | ||
| 261 | return nil | ||
| 262 | } | ||
| 263 | |||
| 264 | // Signature verification. | ||
| 265 | if err := axon.Verify(event); err != nil { | ||
| 266 | c.sendError(400, fmt.Sprintf("invalid event: %v", err), false) | ||
| 267 | return nil | ||
| 268 | } | ||
| 269 | |||
| 270 | // Job request expiry check (kinds 5000–5999). | ||
| 271 | if event.Kind >= 5000 && event.Kind < 6000 { | ||
| 272 | if expired, err := isExpired(event); err != nil { | ||
| 273 | c.sendError(400, fmt.Sprintf("bad expires_at tag: %v", err), false) | ||
| 274 | return nil | ||
| 275 | } else if expired { | ||
| 276 | c.sendError(400, "job request has expired", false) | ||
| 277 | return nil | ||
| 278 | } | ||
| 279 | } | ||
| 280 | |||
| 281 | // Marshal envelope bytes for storage and fanout. | ||
| 282 | envelopeBytes, err := axon.MarshalEvent(event) | ||
| 283 | if err != nil { | ||
| 284 | c.sendError(400, "could not marshal event", false) | ||
| 285 | return nil | ||
| 286 | } | ||
| 287 | |||
| 288 | // Ephemeral events (3000–3999): fanout only, do not store. | ||
| 289 | isEphemeral := event.Kind >= 3000 && event.Kind < 4000 | ||
| 290 | |||
| 291 | if !isEphemeral { | ||
| 292 | // Duplicate check. | ||
| 293 | exists, err := c.store.ExistsByID(ctx, event.ID) | ||
| 294 | if err != nil { | ||
| 295 | c.sendError(500, "internal error", false) | ||
| 296 | return fmt.Errorf("exists check: %w", err) | ||
| 297 | } | ||
| 298 | if exists { | ||
| 299 | c.sendError(409, "duplicate event", false) | ||
| 300 | return nil | ||
| 301 | } | ||
| 302 | |||
| 303 | // Persist. | ||
| 304 | if err := c.store.StoreEvent(ctx, event, envelopeBytes); err != nil { | ||
| 305 | if err == storage.ErrDuplicate { | ||
| 306 | c.sendError(409, "duplicate event", false) | ||
| 307 | return nil | ||
| 308 | } | ||
| 309 | c.sendError(500, "internal error", false) | ||
| 310 | return fmt.Errorf("store event: %w", err) | ||
| 311 | } | ||
| 312 | } | ||
| 313 | |||
| 314 | // Fanout to all matching subscribers. | ||
| 315 | c.global.Fanout(event, envelopeBytes) | ||
| 316 | |||
| 317 | return c.send(MsgTypeOk, OkPayload{Message: "ok"}) | ||
| 318 | } | ||
| 319 | |||
| 320 | // handleSubscribe processes a Subscribe message. | ||
| 321 | func (c *conn) handleSubscribe(ctx context.Context, raw msgpack.RawMessage) error { | ||
| 322 | var p SubscribePayload | ||
| 323 | if err := msgpack.Unmarshal(raw, &p); err != nil { | ||
| 324 | c.sendError(400, "malformed subscribe payload", false) | ||
| 325 | return nil | ||
| 326 | } | ||
| 327 | if p.SubID == "" { | ||
| 328 | c.sendError(400, "sub_id required", false) | ||
| 329 | return nil | ||
| 330 | } | ||
| 331 | |||
| 332 | // Query historical events. | ||
| 333 | envelopes, err := c.store.QueryEvents(ctx, []axon.Filter{p.Filter}) | ||
| 334 | if err != nil { | ||
| 335 | c.sendError(500, "internal error", false) | ||
| 336 | return fmt.Errorf("query events: %w", err) | ||
| 337 | } | ||
| 338 | |||
| 339 | for _, envBytes := range envelopes { | ||
| 340 | ev, err := axon.UnmarshalEvent(envBytes) | ||
| 341 | if err != nil { | ||
| 342 | log.Printf("conn %s: unmarshal stored event: %v", c.id, err) | ||
| 343 | continue | ||
| 344 | } | ||
| 345 | if err := c.send(MsgTypeEvent, EventPayload{SubID: p.SubID, Event: *ev}); err != nil { | ||
| 346 | return err | ||
| 347 | } | ||
| 348 | } | ||
| 349 | |||
| 350 | // Send EOSE. | ||
| 351 | if err := c.send(MsgTypeEose, EosePayload{SubID: p.SubID}); err != nil { | ||
| 352 | return err | ||
| 353 | } | ||
| 354 | |||
| 355 | // Register for live fanout. | ||
| 356 | sub := c.mgr.Add(p.SubID, []axon.Filter{p.Filter}) | ||
| 357 | c.global.Register(c.id, p.SubID, sub) | ||
| 358 | |||
| 359 | // Start goroutine to stream live events to this client. | ||
| 360 | go c.streamSub(sub) | ||
| 361 | |||
| 362 | return nil | ||
| 363 | } | ||
| 364 | |||
| 365 | // streamSub reads from a subscription's Events channel and sends them to the | ||
| 366 | // client. Returns when the subscription or connection is closed. | ||
| 367 | func (c *conn) streamSub(sub *subscription.Subscription) { | ||
| 368 | for envelopeBytes := range sub.Events { | ||
| 369 | ev, err := axon.UnmarshalEvent(envelopeBytes) | ||
| 370 | if err != nil { | ||
| 371 | log.Printf("conn %s: sub %s: unmarshal live event: %v", c.id, sub.ID, err) | ||
| 372 | continue | ||
| 373 | } | ||
| 374 | if err := c.send(MsgTypeEvent, EventPayload{SubID: sub.ID, Event: *ev}); err != nil { | ||
| 375 | return | ||
| 376 | } | ||
| 377 | } | ||
| 378 | } | ||
| 379 | |||
| 380 | // handleUnsubscribe processes an Unsubscribe message. | ||
| 381 | func (c *conn) handleUnsubscribe(raw msgpack.RawMessage) error { | ||
| 382 | var p UnsubscribePayload | ||
| 383 | if err := msgpack.Unmarshal(raw, &p); err != nil { | ||
| 384 | c.sendError(400, "malformed unsubscribe payload", false) | ||
| 385 | return nil | ||
| 386 | } | ||
| 387 | c.mgr.Remove(p.SubID) | ||
| 388 | c.global.Unregister(c.id, p.SubID) | ||
| 389 | return nil | ||
| 390 | } | ||
| 391 | |||
| 392 | // isExpired checks the expires_at tag on a job request event. | ||
| 393 | // Returns (true, nil) if the event is expired, (false, nil) if not, or | ||
| 394 | // (false, err) if the tag is malformed. | ||
| 395 | func isExpired(event *axon.Event) (bool, error) { | ||
| 396 | for _, tag := range event.Tags { | ||
| 397 | if tag.Name != "expires_at" { | ||
| 398 | continue | ||
| 399 | } | ||
| 400 | if len(tag.Values) == 0 { | ||
| 401 | continue | ||
| 402 | } | ||
| 403 | var ts int64 | ||
| 404 | _, err := fmt.Sscanf(tag.Values[0], "%d", &ts) | ||
| 405 | if err != nil { | ||
| 406 | return false, fmt.Errorf("parse expires_at: %w", err) | ||
| 407 | } | ||
| 408 | return time.Now().Unix() > ts, nil | ||
| 409 | } | ||
| 410 | return false, nil | ||
| 411 | } | ||
