From 61a85baf87d89fcc09f9469a113a2ddc982b0a24 Mon Sep 17 00:00:00 2001 From: bndw Date: Mon, 9 Mar 2026 08:01:02 -0700 Subject: feat: phase 2 relay implementation Implement the Axon relay as relay/ (module axon/relay). Includes: - WebSocket framing (RFC 6455, no external deps) in relay/websocket/ - Per-connection auth: challenge/response with ed25519 + allowlist check - Ingest pipeline: sig verify, dedup, ephemeral fanout, SQLite persistence - Subscription manager with prefix-matching filter fanout in relay/subscription/ - SQLite storage with WAL/cache config and UNION query builder in relay/storage/ - Graceful shutdown on SIGINT/SIGTERM - Filter/TagFilter types added to axon core package (required by relay) --- relay/handler.go | 411 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 411 insertions(+) create mode 100644 relay/handler.go (limited to 'relay/handler.go') diff --git a/relay/handler.go b/relay/handler.go new file mode 100644 index 0000000..afd9622 --- /dev/null +++ b/relay/handler.go @@ -0,0 +1,411 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "log" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "axon" + "axon/relay/storage" + "axon/relay/subscription" + ws "axon/relay/websocket" +) + +// Wire protocol message type constants. +const ( + // Client → Relay + MsgTypeAuth uint16 = 1 + MsgTypeSubscribe uint16 = 2 + MsgTypeUnsubscribe uint16 = 3 + MsgTypePublish uint16 = 4 + + // Relay → Client + MsgTypeChallenge uint16 = 10 + MsgTypeEvent uint16 = 11 + MsgTypeEose uint16 = 12 + MsgTypeOk uint16 = 13 + MsgTypeError uint16 = 14 +) + +// Payload types — Relay → Client + +type ChallengePayload struct { + Nonce []byte `msgpack:"nonce"` +} + +type EventPayload struct { + SubID string `msgpack:"sub_id"` + Event axon.Event `msgpack:"event"` +} + +type EosePayload struct { + SubID string `msgpack:"sub_id"` +} + +type OkPayload struct { + Message string `msgpack:"message"` +} + +type ErrorPayload struct { + Code uint16 `msgpack:"code"` + Message string `msgpack:"message"` +} + +// Payload types — Client → Relay + +type AuthPayload struct { + PubKey []byte `msgpack:"pubkey"` + Sig []byte `msgpack:"sig"` +} + +type SubscribePayload struct { + SubID string `msgpack:"sub_id"` + Filter axon.Filter `msgpack:"filter"` +} + +type UnsubscribePayload struct { + SubID string `msgpack:"sub_id"` +} + +type PublishPayload struct { + Event axon.Event `msgpack:"event"` +} + +// conn holds per-connection state. +type conn struct { + id string // unique connection ID (hex nonce for logging) + ws *ws.Conn + store *storage.Storage + global *subscription.GlobalManager + allowlist [][]byte + relayURL string + + authed bool + pubkey []byte + nonce []byte + + mgr *subscription.Manager +} + +// send encodes and writes a wire message to the client. +func (c *conn) send(msgType uint16, payload interface{}) error { + b, err := msgpack.Marshal([]interface{}{msgType, payload}) + if err != nil { + return fmt.Errorf("handler: marshal msg %d: %w", msgType, err) + } + return c.ws.Write(b) +} + +// sendError sends an ErrorPayload. If fatal is true the connection is then closed. +func (c *conn) sendError(code uint16, message string, fatal bool) { + _ = c.send(MsgTypeError, ErrorPayload{Code: code, Message: message}) + if fatal { + c.ws.Close(1000, "") + } +} + +// serve is the main per-connection loop. +func (c *conn) serve(ctx context.Context) { + defer func() { + c.mgr.CloseAll() + c.global.UnregisterConn(c.id) + }() + + // Send challenge immediately on connect. + if err := c.send(MsgTypeChallenge, ChallengePayload{Nonce: c.nonce}); err != nil { + log.Printf("conn %s: send challenge: %v", c.id, err) + return + } + + // Start ping goroutine. + pingStop := make(chan struct{}) + defer close(pingStop) + go c.pingLoop(pingStop) + + for { + raw, err := c.ws.Read(ctx) + if err != nil { + if ctx.Err() != nil { + return + } + // Connection closed or read error. + return + } + + if err := c.dispatch(ctx, raw); err != nil { + log.Printf("conn %s: dispatch: %v", c.id, err) + return + } + } +} + +// pingLoop sends a WebSocket ping every 30 seconds. If two consecutive pings +// go unanswered (no pong received within 30s each) the connection is closed. +func (c *conn) pingLoop(stop <-chan struct{}) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + missed := 0 + for { + select { + case <-stop: + return + case <-ticker.C: + if err := c.ws.Ping(); err != nil { + return + } + missed++ + if missed >= 2 { + log.Printf("conn %s: missed 2 pings, closing", c.id) + c.ws.Close(1001, "ping timeout") + return + } + } + } +} + +// dispatch parses a raw message and routes it to the appropriate handler. +func (c *conn) dispatch(ctx context.Context, raw []byte) error { + // Decode as [uint16, rawPayload] + var arr []msgpack.RawMessage + if err := msgpack.Unmarshal(raw, &arr); err != nil { + c.sendError(400, "malformed message", false) + return nil + } + if len(arr) < 2 { + c.sendError(400, "message too short", false) + return nil + } + + var msgType uint16 + if err := msgpack.Unmarshal(arr[0], &msgType); err != nil { + c.sendError(400, "invalid message type", false) + return nil + } + + // Only MsgTypeAuth is allowed before authentication. + if !c.authed && msgType != MsgTypeAuth { + c.sendError(401, "not authenticated", true) + return fmt.Errorf("unauthenticated message type %d", msgType) + } + + switch msgType { + case MsgTypeAuth: + return c.handleAuth(arr[1]) + case MsgTypePublish: + return c.handlePublish(ctx, arr[1]) + case MsgTypeSubscribe: + return c.handleSubscribe(ctx, arr[1]) + case MsgTypeUnsubscribe: + return c.handleUnsubscribe(arr[1]) + default: + c.sendError(400, fmt.Sprintf("unknown message type %d", msgType), false) + } + return nil +} + +// handleAuth processes an Auth message. +func (c *conn) handleAuth(raw msgpack.RawMessage) error { + var p AuthPayload + if err := msgpack.Unmarshal(raw, &p); err != nil { + c.sendError(400, "malformed auth payload", true) + return fmt.Errorf("unmarshal auth: %w", err) + } + + if len(p.PubKey) != 32 { + c.sendError(400, "pubkey must be 32 bytes", true) + return fmt.Errorf("bad pubkey length %d", len(p.PubKey)) + } + + if !axon.VerifyChallenge(p.PubKey, c.nonce, c.relayURL, p.Sig) { + c.sendError(401, "invalid signature", true) + return fmt.Errorf("auth signature invalid") + } + + // Check allowlist. + if len(c.allowlist) > 0 { + allowed := false + for _, pk := range c.allowlist { + if bytes.Equal(pk, p.PubKey) { + allowed = true + break + } + } + if !allowed { + c.sendError(403, "pubkey not in allowlist", true) + return fmt.Errorf("pubkey not in allowlist") + } + } + + c.authed = true + c.pubkey = p.PubKey + return c.send(MsgTypeOk, OkPayload{Message: "authenticated"}) +} + +// handlePublish processes a Publish message. +func (c *conn) handlePublish(ctx context.Context, raw msgpack.RawMessage) error { + var p PublishPayload + if err := msgpack.Unmarshal(raw, &p); err != nil { + c.sendError(400, "malformed publish payload", false) + return nil + } + + event := &p.Event + + // Content length check. + if len(event.Content) > 65536 { + c.sendError(413, "content exceeds 64KB limit", false) + return nil + } + + // Signature verification. + if err := axon.Verify(event); err != nil { + c.sendError(400, fmt.Sprintf("invalid event: %v", err), false) + return nil + } + + // Job request expiry check (kinds 5000–5999). + if event.Kind >= 5000 && event.Kind < 6000 { + if expired, err := isExpired(event); err != nil { + c.sendError(400, fmt.Sprintf("bad expires_at tag: %v", err), false) + return nil + } else if expired { + c.sendError(400, "job request has expired", false) + return nil + } + } + + // Marshal envelope bytes for storage and fanout. + envelopeBytes, err := axon.MarshalEvent(event) + if err != nil { + c.sendError(400, "could not marshal event", false) + return nil + } + + // Ephemeral events (3000–3999): fanout only, do not store. + isEphemeral := event.Kind >= 3000 && event.Kind < 4000 + + if !isEphemeral { + // Duplicate check. + exists, err := c.store.ExistsByID(ctx, event.ID) + if err != nil { + c.sendError(500, "internal error", false) + return fmt.Errorf("exists check: %w", err) + } + if exists { + c.sendError(409, "duplicate event", false) + return nil + } + + // Persist. + if err := c.store.StoreEvent(ctx, event, envelopeBytes); err != nil { + if err == storage.ErrDuplicate { + c.sendError(409, "duplicate event", false) + return nil + } + c.sendError(500, "internal error", false) + return fmt.Errorf("store event: %w", err) + } + } + + // Fanout to all matching subscribers. + c.global.Fanout(event, envelopeBytes) + + return c.send(MsgTypeOk, OkPayload{Message: "ok"}) +} + +// handleSubscribe processes a Subscribe message. +func (c *conn) handleSubscribe(ctx context.Context, raw msgpack.RawMessage) error { + var p SubscribePayload + if err := msgpack.Unmarshal(raw, &p); err != nil { + c.sendError(400, "malformed subscribe payload", false) + return nil + } + if p.SubID == "" { + c.sendError(400, "sub_id required", false) + return nil + } + + // Query historical events. + envelopes, err := c.store.QueryEvents(ctx, []axon.Filter{p.Filter}) + if err != nil { + c.sendError(500, "internal error", false) + return fmt.Errorf("query events: %w", err) + } + + for _, envBytes := range envelopes { + ev, err := axon.UnmarshalEvent(envBytes) + if err != nil { + log.Printf("conn %s: unmarshal stored event: %v", c.id, err) + continue + } + if err := c.send(MsgTypeEvent, EventPayload{SubID: p.SubID, Event: *ev}); err != nil { + return err + } + } + + // Send EOSE. + if err := c.send(MsgTypeEose, EosePayload{SubID: p.SubID}); err != nil { + return err + } + + // Register for live fanout. + sub := c.mgr.Add(p.SubID, []axon.Filter{p.Filter}) + c.global.Register(c.id, p.SubID, sub) + + // Start goroutine to stream live events to this client. + go c.streamSub(sub) + + return nil +} + +// streamSub reads from a subscription's Events channel and sends them to the +// client. Returns when the subscription or connection is closed. +func (c *conn) streamSub(sub *subscription.Subscription) { + for envelopeBytes := range sub.Events { + ev, err := axon.UnmarshalEvent(envelopeBytes) + if err != nil { + log.Printf("conn %s: sub %s: unmarshal live event: %v", c.id, sub.ID, err) + continue + } + if err := c.send(MsgTypeEvent, EventPayload{SubID: sub.ID, Event: *ev}); err != nil { + return + } + } +} + +// handleUnsubscribe processes an Unsubscribe message. +func (c *conn) handleUnsubscribe(raw msgpack.RawMessage) error { + var p UnsubscribePayload + if err := msgpack.Unmarshal(raw, &p); err != nil { + c.sendError(400, "malformed unsubscribe payload", false) + return nil + } + c.mgr.Remove(p.SubID) + c.global.Unregister(c.id, p.SubID) + return nil +} + +// isExpired checks the expires_at tag on a job request event. +// Returns (true, nil) if the event is expired, (false, nil) if not, or +// (false, err) if the tag is malformed. +func isExpired(event *axon.Event) (bool, error) { + for _, tag := range event.Tags { + if tag.Name != "expires_at" { + continue + } + if len(tag.Values) == 0 { + continue + } + var ts int64 + _, err := fmt.Sscanf(tag.Values[0], "%d", &ts) + if err != nil { + return false, fmt.Errorf("parse expires_at: %w", err) + } + return time.Now().Unix() > ts, nil + } + return false, nil +} -- cgit v1.2.3