package main import ( "bytes" "context" "fmt" "log" "time" "github.com/vmihailenco/msgpack/v5" "code.northwest.io/axon" "code.northwest.io/axon/relay/storage" "code.northwest.io/axon/relay/subscription" ws "code.northwest.io/axon/relay/websocket" ) // Wire protocol message type constants. const ( // Client → Relay MsgTypeAuth uint16 = 1 MsgTypeSubscribe uint16 = 2 MsgTypeUnsubscribe uint16 = 3 MsgTypePublish uint16 = 4 // Relay → Client MsgTypeChallenge uint16 = 10 MsgTypeEvent uint16 = 11 MsgTypeEose uint16 = 12 MsgTypeOk uint16 = 13 MsgTypeError uint16 = 14 ) // Payload types — Relay → Client type ChallengePayload struct { Nonce []byte `msgpack:"nonce"` } type EventPayload struct { SubID string `msgpack:"sub_id"` Event axon.Event `msgpack:"event"` } type EosePayload struct { SubID string `msgpack:"sub_id"` } type OkPayload struct { Message string `msgpack:"message"` } type ErrorPayload struct { Code uint16 `msgpack:"code"` Message string `msgpack:"message"` } // Payload types — Client → Relay type AuthPayload struct { PubKey []byte `msgpack:"pubkey"` Sig []byte `msgpack:"sig"` } type SubscribePayload struct { SubID string `msgpack:"sub_id"` Filter axon.Filter `msgpack:"filter"` } type UnsubscribePayload struct { SubID string `msgpack:"sub_id"` } type PublishPayload struct { Event axon.Event `msgpack:"event"` } // conn holds per-connection state. type conn struct { id string // unique connection ID (hex nonce for logging) ws *ws.Conn store *storage.Storage global *subscription.GlobalManager allowlist [][]byte relayURL string authed bool pubkey []byte nonce []byte mgr *subscription.Manager } // send encodes and writes a wire message to the client. func (c *conn) send(msgType uint16, payload interface{}) error { b, err := msgpack.Marshal([]interface{}{msgType, payload}) if err != nil { return fmt.Errorf("handler: marshal msg %d: %w", msgType, err) } return c.ws.Write(b) } // sendError sends an ErrorPayload. If fatal is true the connection is then closed. func (c *conn) sendError(code uint16, message string, fatal bool) { _ = c.send(MsgTypeError, ErrorPayload{Code: code, Message: message}) if fatal { c.ws.Close(1000, "") } } // serve is the main per-connection loop. func (c *conn) serve(ctx context.Context) { defer func() { c.mgr.CloseAll() c.global.UnregisterConn(c.id) }() // Send challenge immediately on connect. if err := c.send(MsgTypeChallenge, ChallengePayload{Nonce: c.nonce}); err != nil { log.Printf("conn %s: send challenge: %v", c.id, err) return } // Start ping goroutine. pingStop := make(chan struct{}) defer close(pingStop) go c.pingLoop(pingStop) for { raw, err := c.ws.Read(ctx) if err != nil { if ctx.Err() != nil { return } // Connection closed or read error. return } if err := c.dispatch(ctx, raw); err != nil { log.Printf("conn %s: dispatch: %v", c.id, err) return } } } // pingLoop sends a WebSocket ping every 30 seconds. If two consecutive pings // go unanswered (no pong received within 30s each) the connection is closed. func (c *conn) pingLoop(stop <-chan struct{}) { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() missed := 0 for { select { case <-stop: return case <-ticker.C: if err := c.ws.Ping(); err != nil { return } missed++ if missed >= 2 { log.Printf("conn %s: missed 2 pings, closing", c.id) c.ws.Close(1001, "ping timeout") return } } } } // dispatch parses a raw message and routes it to the appropriate handler. func (c *conn) dispatch(ctx context.Context, raw []byte) error { // Decode as [uint16, rawPayload] var arr []msgpack.RawMessage if err := msgpack.Unmarshal(raw, &arr); err != nil { c.sendError(400, "malformed message", false) return nil } if len(arr) < 2 { c.sendError(400, "message too short", false) return nil } var msgType uint16 if err := msgpack.Unmarshal(arr[0], &msgType); err != nil { c.sendError(400, "invalid message type", false) return nil } // Only MsgTypeAuth is allowed before authentication. if !c.authed && msgType != MsgTypeAuth { c.sendError(401, "not authenticated", true) return fmt.Errorf("unauthenticated message type %d", msgType) } switch msgType { case MsgTypeAuth: return c.handleAuth(arr[1]) case MsgTypePublish: return c.handlePublish(ctx, arr[1]) case MsgTypeSubscribe: return c.handleSubscribe(ctx, arr[1]) case MsgTypeUnsubscribe: return c.handleUnsubscribe(arr[1]) default: c.sendError(400, fmt.Sprintf("unknown message type %d", msgType), false) } return nil } // handleAuth processes an Auth message. func (c *conn) handleAuth(raw msgpack.RawMessage) error { var p AuthPayload if err := msgpack.Unmarshal(raw, &p); err != nil { c.sendError(400, "malformed auth payload", true) return fmt.Errorf("unmarshal auth: %w", err) } if len(p.PubKey) != 32 { c.sendError(400, "pubkey must be 32 bytes", true) return fmt.Errorf("bad pubkey length %d", len(p.PubKey)) } if !axon.VerifyChallenge(p.PubKey, c.nonce, c.relayURL, p.Sig) { c.sendError(401, "invalid signature", true) return fmt.Errorf("auth signature invalid") } // Check allowlist. if len(c.allowlist) > 0 { allowed := false for _, pk := range c.allowlist { if bytes.Equal(pk, p.PubKey) { allowed = true break } } if !allowed { c.sendError(403, "pubkey not in allowlist", true) return fmt.Errorf("pubkey not in allowlist") } } c.authed = true c.pubkey = p.PubKey return c.send(MsgTypeOk, OkPayload{Message: "authenticated"}) } // handlePublish processes a Publish message. func (c *conn) handlePublish(ctx context.Context, raw msgpack.RawMessage) error { var p PublishPayload if err := msgpack.Unmarshal(raw, &p); err != nil { c.sendError(400, "malformed publish payload", false) return nil } event := &p.Event // Content length check. if len(event.Content) > 65536 { c.sendError(413, "content exceeds 64KB limit", false) return nil } // Signature verification. if err := axon.Verify(event); err != nil { c.sendError(400, fmt.Sprintf("invalid event: %v", err), false) return nil } // Job request expiry check (kinds 5000–5999). if event.Kind >= 5000 && event.Kind < 6000 { if expired, err := isExpired(event); err != nil { c.sendError(400, fmt.Sprintf("bad expires_at tag: %v", err), false) return nil } else if expired { c.sendError(400, "job request has expired", false) return nil } } // Marshal envelope bytes for storage and fanout. envelopeBytes, err := axon.MarshalEvent(event) if err != nil { c.sendError(400, "could not marshal event", false) return nil } // Ephemeral events (3000–3999): fanout only, do not store. isEphemeral := event.Kind >= 3000 && event.Kind < 4000 if !isEphemeral { // Duplicate check. exists, err := c.store.ExistsByID(ctx, event.ID) if err != nil { c.sendError(500, "internal error", false) return fmt.Errorf("exists check: %w", err) } if exists { c.sendError(409, "duplicate event", false) return nil } // Persist. if err := c.store.StoreEvent(ctx, event, envelopeBytes); err != nil { if err == storage.ErrDuplicate { c.sendError(409, "duplicate event", false) return nil } c.sendError(500, "internal error", false) return fmt.Errorf("store event: %w", err) } } // Fanout to all matching subscribers. c.global.Fanout(event, envelopeBytes) return c.send(MsgTypeOk, OkPayload{Message: "ok"}) } // handleSubscribe processes a Subscribe message. func (c *conn) handleSubscribe(ctx context.Context, raw msgpack.RawMessage) error { var p SubscribePayload if err := msgpack.Unmarshal(raw, &p); err != nil { c.sendError(400, "malformed subscribe payload", false) return nil } if p.SubID == "" { c.sendError(400, "sub_id required", false) return nil } // Query historical events. envelopes, err := c.store.QueryEvents(ctx, []axon.Filter{p.Filter}) if err != nil { c.sendError(500, "internal error", false) return fmt.Errorf("query events: %w", err) } for _, envBytes := range envelopes { ev, err := axon.UnmarshalEvent(envBytes) if err != nil { log.Printf("conn %s: unmarshal stored event: %v", c.id, err) continue } if err := c.send(MsgTypeEvent, EventPayload{SubID: p.SubID, Event: *ev}); err != nil { return err } } // Send EOSE. if err := c.send(MsgTypeEose, EosePayload{SubID: p.SubID}); err != nil { return err } // Register for live fanout. sub := c.mgr.Add(p.SubID, []axon.Filter{p.Filter}) c.global.Register(c.id, p.SubID, sub) // Start goroutine to stream live events to this client. go c.streamSub(sub) return nil } // streamSub reads from a subscription's Events channel and sends them to the // client. Returns when the subscription or connection is closed. func (c *conn) streamSub(sub *subscription.Subscription) { for envelopeBytes := range sub.Events { ev, err := axon.UnmarshalEvent(envelopeBytes) if err != nil { log.Printf("conn %s: sub %s: unmarshal live event: %v", c.id, sub.ID, err) continue } if err := c.send(MsgTypeEvent, EventPayload{SubID: sub.ID, Event: *ev}); err != nil { return } } } // handleUnsubscribe processes an Unsubscribe message. func (c *conn) handleUnsubscribe(raw msgpack.RawMessage) error { var p UnsubscribePayload if err := msgpack.Unmarshal(raw, &p); err != nil { c.sendError(400, "malformed unsubscribe payload", false) return nil } c.mgr.Remove(p.SubID) c.global.Unregister(c.id, p.SubID) return nil } // isExpired checks the expires_at tag on a job request event. // Returns (true, nil) if the event is expired, (false, nil) if not, or // (false, err) if the tag is malformed. func isExpired(event *axon.Event) (bool, error) { for _, tag := range event.Tags { if tag.Name != "expires_at" { continue } if len(tag.Values) == 0 { continue } var ts int64 _, err := fmt.Sscanf(tag.Values[0], "%d", &ts) if err != nil { return false, fmt.Errorf("parse expires_at: %w", err) } return time.Now().Unix() > ts, nil } return false, nil }