From 61a85baf87d89fcc09f9469a113a2ddc982b0a24 Mon Sep 17 00:00:00 2001 From: bndw Date: Mon, 9 Mar 2026 08:01:02 -0700 Subject: feat: phase 2 relay implementation Implement the Axon relay as relay/ (module axon/relay). Includes: - WebSocket framing (RFC 6455, no external deps) in relay/websocket/ - Per-connection auth: challenge/response with ed25519 + allowlist check - Ingest pipeline: sig verify, dedup, ephemeral fanout, SQLite persistence - Subscription manager with prefix-matching filter fanout in relay/subscription/ - SQLite storage with WAL/cache config and UNION query builder in relay/storage/ - Graceful shutdown on SIGINT/SIGTERM - Filter/TagFilter types added to axon core package (required by relay) --- axon.go | 22 +++ relay/config.go | 66 +++++++ relay/go.mod | 30 +++ relay/go.sum | 63 +++++++ relay/handler.go | 411 ++++++++++++++++++++++++++++++++++++++++++ relay/main.go | 72 ++++++++ relay/server.go | 118 ++++++++++++ relay/storage/events.go | 192 ++++++++++++++++++++ relay/storage/storage.go | 86 +++++++++ relay/subscription/manager.go | 317 ++++++++++++++++++++++++++++++++ relay/websocket/websocket.go | 244 +++++++++++++++++++++++++ 11 files changed, 1621 insertions(+) create mode 100644 relay/config.go create mode 100644 relay/go.mod create mode 100644 relay/go.sum create mode 100644 relay/handler.go create mode 100644 relay/main.go create mode 100644 relay/server.go create mode 100644 relay/storage/events.go create mode 100644 relay/storage/storage.go create mode 100644 relay/subscription/manager.go create mode 100644 relay/websocket/websocket.go diff --git a/axon.go b/axon.go index 51ec22d..ebe18a2 100644 --- a/axon.go +++ b/axon.go @@ -29,6 +29,28 @@ type Tag struct { Values []string `msgpack:"values" json:"values"` } +// TagFilter selects events that have a tag with the given name and any of the +// given values. An empty Values slice matches any value. +type TagFilter struct { + Name string `msgpack:"name"` + Values []string `msgpack:"values"` +} + +// Filter selects a subset of events. All non-empty fields are ANDed together; +// multiple entries within a slice field are ORed. +// +// IDs and Authors support prefix matching: a []byte shorter than 32 bytes +// matches any event whose ID (or pubkey) starts with those bytes. +type Filter struct { + IDs [][]byte `msgpack:"ids"` + Authors [][]byte `msgpack:"authors"` + Kinds []uint16 `msgpack:"kinds"` + Since int64 `msgpack:"since"` // inclusive lower bound on created_at + Until int64 `msgpack:"until"` // inclusive upper bound on created_at + Limit int32 `msgpack:"limit"` // max events to return (0 = no limit) + Tags []TagFilter `msgpack:"tags"` +} + // Event is the core Axon data structure. All fields use their wire types. // id, pubkey and sig are raw 32/64-byte slices, not hex. // content is opaque bytes (msgpack bin type). diff --git a/relay/config.go b/relay/config.go new file mode 100644 index 0000000..e432b85 --- /dev/null +++ b/relay/config.go @@ -0,0 +1,66 @@ +package main + +import ( + "encoding/hex" + "fmt" + "os" + + "gopkg.in/yaml.v3" +) + +// Config holds all relay configuration loaded from config.yaml. +type Config struct { + Addr string `yaml:"addr"` + DB string `yaml:"db"` + RelayURL string `yaml:"relay_url"` + Allowlist []string `yaml:"allowlist"` // hex-encoded pubkeys +} + +// DefaultConfig returns sensible defaults. +func DefaultConfig() Config { + return Config{ + Addr: ":8080", + DB: "axon.db", + RelayURL: "ws://localhost:8080", + } +} + +// AllowlistBytes decodes the hex pubkeys in c.Allowlist and returns them as +// raw byte slices. Returns an error if any entry is not valid 64-char hex. +func (c *Config) AllowlistBytes() ([][]byte, error) { + out := make([][]byte, 0, len(c.Allowlist)) + for _, h := range c.Allowlist { + b, err := hex.DecodeString(h) + if err != nil { + return nil, fmt.Errorf("config: allowlist entry %q is not valid hex: %w", h, err) + } + if len(b) != 32 { + return nil, fmt.Errorf("config: allowlist entry %q decoded to %d bytes, want 32", h, len(b)) + } + out = append(out, b) + } + return out, nil +} + +// LoadConfig reads and parses a YAML config file. Missing fields fall back to +// DefaultConfig values. +func LoadConfig(path string) (Config, error) { + cfg := DefaultConfig() + + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + // No config file — use defaults. + return cfg, nil + } + return cfg, fmt.Errorf("config: open %q: %w", path, err) + } + defer f.Close() + + dec := yaml.NewDecoder(f) + dec.KnownFields(true) + if err := dec.Decode(&cfg); err != nil { + return cfg, fmt.Errorf("config: decode %q: %w", path, err) + } + return cfg, nil +} diff --git a/relay/go.mod b/relay/go.mod new file mode 100644 index 0000000..a3d424a --- /dev/null +++ b/relay/go.mod @@ -0,0 +1,30 @@ +module axon/relay + +go 1.25.5 + +require ( + axon v0.0.0 + github.com/vmihailenco/msgpack/v5 v5.4.1 + gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.33.1 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/sys v0.41.0 // indirect + modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect + modernc.org/libc v1.55.3 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.8.0 // indirect + modernc.org/strutil v1.2.0 // indirect + modernc.org/token v1.1.0 // indirect +) + +replace axon => ../ diff --git a/relay/go.sum b/relay/go.sum new file mode 100644 index 0000000..04de341 --- /dev/null +++ b/relay/go.sum @@ -0,0 +1,63 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= +modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= +modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y= +modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw= +modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= +modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= +modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= +modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= +modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= +modernc.org/sqlite v1.33.1 h1:trb6Z3YYoeM9eDL1O8do81kP+0ejv+YzgyFo+Gwy0nM= +modernc.org/sqlite v1.33.1/go.mod h1:pXV2xHxhzXZsgT/RtTFAPY6JJDEvOTcTdwADQCCWD4k= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/relay/handler.go b/relay/handler.go new file mode 100644 index 0000000..afd9622 --- /dev/null +++ b/relay/handler.go @@ -0,0 +1,411 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "log" + "time" + + "github.com/vmihailenco/msgpack/v5" + + "axon" + "axon/relay/storage" + "axon/relay/subscription" + ws "axon/relay/websocket" +) + +// Wire protocol message type constants. +const ( + // Client → Relay + MsgTypeAuth uint16 = 1 + MsgTypeSubscribe uint16 = 2 + MsgTypeUnsubscribe uint16 = 3 + MsgTypePublish uint16 = 4 + + // Relay → Client + MsgTypeChallenge uint16 = 10 + MsgTypeEvent uint16 = 11 + MsgTypeEose uint16 = 12 + MsgTypeOk uint16 = 13 + MsgTypeError uint16 = 14 +) + +// Payload types — Relay → Client + +type ChallengePayload struct { + Nonce []byte `msgpack:"nonce"` +} + +type EventPayload struct { + SubID string `msgpack:"sub_id"` + Event axon.Event `msgpack:"event"` +} + +type EosePayload struct { + SubID string `msgpack:"sub_id"` +} + +type OkPayload struct { + Message string `msgpack:"message"` +} + +type ErrorPayload struct { + Code uint16 `msgpack:"code"` + Message string `msgpack:"message"` +} + +// Payload types — Client → Relay + +type AuthPayload struct { + PubKey []byte `msgpack:"pubkey"` + Sig []byte `msgpack:"sig"` +} + +type SubscribePayload struct { + SubID string `msgpack:"sub_id"` + Filter axon.Filter `msgpack:"filter"` +} + +type UnsubscribePayload struct { + SubID string `msgpack:"sub_id"` +} + +type PublishPayload struct { + Event axon.Event `msgpack:"event"` +} + +// conn holds per-connection state. +type conn struct { + id string // unique connection ID (hex nonce for logging) + ws *ws.Conn + store *storage.Storage + global *subscription.GlobalManager + allowlist [][]byte + relayURL string + + authed bool + pubkey []byte + nonce []byte + + mgr *subscription.Manager +} + +// send encodes and writes a wire message to the client. +func (c *conn) send(msgType uint16, payload interface{}) error { + b, err := msgpack.Marshal([]interface{}{msgType, payload}) + if err != nil { + return fmt.Errorf("handler: marshal msg %d: %w", msgType, err) + } + return c.ws.Write(b) +} + +// sendError sends an ErrorPayload. If fatal is true the connection is then closed. +func (c *conn) sendError(code uint16, message string, fatal bool) { + _ = c.send(MsgTypeError, ErrorPayload{Code: code, Message: message}) + if fatal { + c.ws.Close(1000, "") + } +} + +// serve is the main per-connection loop. +func (c *conn) serve(ctx context.Context) { + defer func() { + c.mgr.CloseAll() + c.global.UnregisterConn(c.id) + }() + + // Send challenge immediately on connect. + if err := c.send(MsgTypeChallenge, ChallengePayload{Nonce: c.nonce}); err != nil { + log.Printf("conn %s: send challenge: %v", c.id, err) + return + } + + // Start ping goroutine. + pingStop := make(chan struct{}) + defer close(pingStop) + go c.pingLoop(pingStop) + + for { + raw, err := c.ws.Read(ctx) + if err != nil { + if ctx.Err() != nil { + return + } + // Connection closed or read error. + return + } + + if err := c.dispatch(ctx, raw); err != nil { + log.Printf("conn %s: dispatch: %v", c.id, err) + return + } + } +} + +// pingLoop sends a WebSocket ping every 30 seconds. If two consecutive pings +// go unanswered (no pong received within 30s each) the connection is closed. +func (c *conn) pingLoop(stop <-chan struct{}) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + missed := 0 + for { + select { + case <-stop: + return + case <-ticker.C: + if err := c.ws.Ping(); err != nil { + return + } + missed++ + if missed >= 2 { + log.Printf("conn %s: missed 2 pings, closing", c.id) + c.ws.Close(1001, "ping timeout") + return + } + } + } +} + +// dispatch parses a raw message and routes it to the appropriate handler. +func (c *conn) dispatch(ctx context.Context, raw []byte) error { + // Decode as [uint16, rawPayload] + var arr []msgpack.RawMessage + if err := msgpack.Unmarshal(raw, &arr); err != nil { + c.sendError(400, "malformed message", false) + return nil + } + if len(arr) < 2 { + c.sendError(400, "message too short", false) + return nil + } + + var msgType uint16 + if err := msgpack.Unmarshal(arr[0], &msgType); err != nil { + c.sendError(400, "invalid message type", false) + return nil + } + + // Only MsgTypeAuth is allowed before authentication. + if !c.authed && msgType != MsgTypeAuth { + c.sendError(401, "not authenticated", true) + return fmt.Errorf("unauthenticated message type %d", msgType) + } + + switch msgType { + case MsgTypeAuth: + return c.handleAuth(arr[1]) + case MsgTypePublish: + return c.handlePublish(ctx, arr[1]) + case MsgTypeSubscribe: + return c.handleSubscribe(ctx, arr[1]) + case MsgTypeUnsubscribe: + return c.handleUnsubscribe(arr[1]) + default: + c.sendError(400, fmt.Sprintf("unknown message type %d", msgType), false) + } + return nil +} + +// handleAuth processes an Auth message. +func (c *conn) handleAuth(raw msgpack.RawMessage) error { + var p AuthPayload + if err := msgpack.Unmarshal(raw, &p); err != nil { + c.sendError(400, "malformed auth payload", true) + return fmt.Errorf("unmarshal auth: %w", err) + } + + if len(p.PubKey) != 32 { + c.sendError(400, "pubkey must be 32 bytes", true) + return fmt.Errorf("bad pubkey length %d", len(p.PubKey)) + } + + if !axon.VerifyChallenge(p.PubKey, c.nonce, c.relayURL, p.Sig) { + c.sendError(401, "invalid signature", true) + return fmt.Errorf("auth signature invalid") + } + + // Check allowlist. + if len(c.allowlist) > 0 { + allowed := false + for _, pk := range c.allowlist { + if bytes.Equal(pk, p.PubKey) { + allowed = true + break + } + } + if !allowed { + c.sendError(403, "pubkey not in allowlist", true) + return fmt.Errorf("pubkey not in allowlist") + } + } + + c.authed = true + c.pubkey = p.PubKey + return c.send(MsgTypeOk, OkPayload{Message: "authenticated"}) +} + +// handlePublish processes a Publish message. +func (c *conn) handlePublish(ctx context.Context, raw msgpack.RawMessage) error { + var p PublishPayload + if err := msgpack.Unmarshal(raw, &p); err != nil { + c.sendError(400, "malformed publish payload", false) + return nil + } + + event := &p.Event + + // Content length check. + if len(event.Content) > 65536 { + c.sendError(413, "content exceeds 64KB limit", false) + return nil + } + + // Signature verification. + if err := axon.Verify(event); err != nil { + c.sendError(400, fmt.Sprintf("invalid event: %v", err), false) + return nil + } + + // Job request expiry check (kinds 5000–5999). + if event.Kind >= 5000 && event.Kind < 6000 { + if expired, err := isExpired(event); err != nil { + c.sendError(400, fmt.Sprintf("bad expires_at tag: %v", err), false) + return nil + } else if expired { + c.sendError(400, "job request has expired", false) + return nil + } + } + + // Marshal envelope bytes for storage and fanout. + envelopeBytes, err := axon.MarshalEvent(event) + if err != nil { + c.sendError(400, "could not marshal event", false) + return nil + } + + // Ephemeral events (3000–3999): fanout only, do not store. + isEphemeral := event.Kind >= 3000 && event.Kind < 4000 + + if !isEphemeral { + // Duplicate check. + exists, err := c.store.ExistsByID(ctx, event.ID) + if err != nil { + c.sendError(500, "internal error", false) + return fmt.Errorf("exists check: %w", err) + } + if exists { + c.sendError(409, "duplicate event", false) + return nil + } + + // Persist. + if err := c.store.StoreEvent(ctx, event, envelopeBytes); err != nil { + if err == storage.ErrDuplicate { + c.sendError(409, "duplicate event", false) + return nil + } + c.sendError(500, "internal error", false) + return fmt.Errorf("store event: %w", err) + } + } + + // Fanout to all matching subscribers. + c.global.Fanout(event, envelopeBytes) + + return c.send(MsgTypeOk, OkPayload{Message: "ok"}) +} + +// handleSubscribe processes a Subscribe message. +func (c *conn) handleSubscribe(ctx context.Context, raw msgpack.RawMessage) error { + var p SubscribePayload + if err := msgpack.Unmarshal(raw, &p); err != nil { + c.sendError(400, "malformed subscribe payload", false) + return nil + } + if p.SubID == "" { + c.sendError(400, "sub_id required", false) + return nil + } + + // Query historical events. + envelopes, err := c.store.QueryEvents(ctx, []axon.Filter{p.Filter}) + if err != nil { + c.sendError(500, "internal error", false) + return fmt.Errorf("query events: %w", err) + } + + for _, envBytes := range envelopes { + ev, err := axon.UnmarshalEvent(envBytes) + if err != nil { + log.Printf("conn %s: unmarshal stored event: %v", c.id, err) + continue + } + if err := c.send(MsgTypeEvent, EventPayload{SubID: p.SubID, Event: *ev}); err != nil { + return err + } + } + + // Send EOSE. + if err := c.send(MsgTypeEose, EosePayload{SubID: p.SubID}); err != nil { + return err + } + + // Register for live fanout. + sub := c.mgr.Add(p.SubID, []axon.Filter{p.Filter}) + c.global.Register(c.id, p.SubID, sub) + + // Start goroutine to stream live events to this client. + go c.streamSub(sub) + + return nil +} + +// streamSub reads from a subscription's Events channel and sends them to the +// client. Returns when the subscription or connection is closed. +func (c *conn) streamSub(sub *subscription.Subscription) { + for envelopeBytes := range sub.Events { + ev, err := axon.UnmarshalEvent(envelopeBytes) + if err != nil { + log.Printf("conn %s: sub %s: unmarshal live event: %v", c.id, sub.ID, err) + continue + } + if err := c.send(MsgTypeEvent, EventPayload{SubID: sub.ID, Event: *ev}); err != nil { + return + } + } +} + +// handleUnsubscribe processes an Unsubscribe message. +func (c *conn) handleUnsubscribe(raw msgpack.RawMessage) error { + var p UnsubscribePayload + if err := msgpack.Unmarshal(raw, &p); err != nil { + c.sendError(400, "malformed unsubscribe payload", false) + return nil + } + c.mgr.Remove(p.SubID) + c.global.Unregister(c.id, p.SubID) + return nil +} + +// isExpired checks the expires_at tag on a job request event. +// Returns (true, nil) if the event is expired, (false, nil) if not, or +// (false, err) if the tag is malformed. +func isExpired(event *axon.Event) (bool, error) { + for _, tag := range event.Tags { + if tag.Name != "expires_at" { + continue + } + if len(tag.Values) == 0 { + continue + } + var ts int64 + _, err := fmt.Sscanf(tag.Values[0], "%d", &ts) + if err != nil { + return false, fmt.Errorf("parse expires_at: %w", err) + } + return time.Now().Unix() > ts, nil + } + return false, nil +} diff --git a/relay/main.go b/relay/main.go new file mode 100644 index 0000000..2cfa034 --- /dev/null +++ b/relay/main.go @@ -0,0 +1,72 @@ +package main + +import ( + "context" + "errors" + "flag" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "axon/relay/storage" + "axon/relay/subscription" +) + +func main() { + cfgPath := flag.String("config", "config.yaml", "path to config.yaml") + flag.Parse() + + cfg, err := LoadConfig(*cfgPath) + if err != nil { + log.Fatalf("relay: load config: %v", err) + } + + allowlist, err := cfg.AllowlistBytes() + if err != nil { + log.Fatalf("relay: allowlist: %v", err) + } + + store, err := storage.New(cfg.DB) + if err != nil { + log.Fatalf("relay: open storage: %v", err) + } + defer store.Close() + + global := subscription.NewGlobalManager() + + // Periodically purge closed subscriptions. + stopPurger := make(chan struct{}) + global.StartPurger(5*time.Minute, stopPurger) + defer close(stopPurger) + + srv := NewServer(cfg, allowlist, store, global) + + // Graceful shutdown on SIGINT / SIGTERM. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + errCh := make(chan error, 1) + go func() { + if err := srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + select { + case sig := <-sigCh: + log.Printf("relay: received signal %s, shutting down", sig) + case err := <-errCh: + log.Fatalf("relay: server error: %v", err) + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := srv.Shutdown(shutdownCtx); err != nil { + log.Printf("relay: shutdown error: %v", err) + } + log.Println("relay: stopped") +} diff --git a/relay/server.go b/relay/server.go new file mode 100644 index 0000000..085929c --- /dev/null +++ b/relay/server.go @@ -0,0 +1,118 @@ +package main + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "log" + "net/http" + "sync" + + "axon/relay/storage" + "axon/relay/subscription" + ws "axon/relay/websocket" +) + +// Server is the HTTP + WebSocket server for the Axon relay. +type Server struct { + cfg Config + allowlist [][]byte + store *storage.Storage + global *subscription.GlobalManager + + mu sync.WaitGroup + httpSrv *http.Server +} + +// NewServer creates a Server from the given config. +func NewServer(cfg Config, allowlist [][]byte, store *storage.Storage, global *subscription.GlobalManager) *Server { + return &Server{ + cfg: cfg, + allowlist: allowlist, + store: store, + global: global, + } +} + +// Start configures the HTTP server and starts listening. Call Shutdown to stop. +func (s *Server) Start() error { + mux := http.NewServeMux() + mux.HandleFunc("/", s.handleWS) + + s.httpSrv = &http.Server{ + Addr: s.cfg.Addr, + Handler: mux, + } + + log.Printf("relay: listening on %s", s.cfg.Addr) + return s.httpSrv.ListenAndServe() +} + +// Shutdown gracefully stops the server and waits for all connections to drain. +func (s *Server) Shutdown(ctx context.Context) error { + err := s.httpSrv.Shutdown(ctx) + // Wait for all handler goroutines to finish. + done := make(chan struct{}) + go func() { + s.mu.Wait() + close(done) + }() + select { + case <-done: + case <-ctx.Done(): + } + return err +} + +// handleWS upgrades an HTTP request to a WebSocket connection and starts the +// per-connection handler goroutine. +func (s *Server) handleWS(w http.ResponseWriter, r *http.Request) { + c, err := ws.Accept(w, r) + if err != nil { + http.Error(w, "WebSocket upgrade failed", http.StatusBadRequest) + return + } + + // Generate 32-byte nonce for the auth challenge. + nonce := make([]byte, 32) + if _, err := rand.Read(nonce); err != nil { + log.Printf("relay: generate nonce: %v", err) + c.CloseConn() + return + } + + connID := hex.EncodeToString(nonce[:8]) + + h := &conn{ + id: connID, + ws: c, + store: s.store, + global: s.global, + allowlist: s.allowlist, + relayURL: s.cfg.RelayURL, + nonce: nonce, + mgr: subscription.NewManager(), + } + + s.mu.Add(1) + go func() { + defer s.mu.Done() + ctx := r.Context() + h.serve(ctx) + if err := c.CloseConn(); err != nil { + // Ignore close errors — connection may already be gone. + _ = err + } + log.Printf("conn %s: closed", connID) + }() +} + +// generateConnID creates a unique connection identifier for logging. +func generateConnID() (string, error) { + var b [8]byte + if _, err := rand.Read(b[:]); err != nil { + return "", fmt.Errorf("server: generate conn id: %w", err) + } + return hex.EncodeToString(b[:]), nil +} diff --git a/relay/storage/events.go b/relay/storage/events.go new file mode 100644 index 0000000..cf10097 --- /dev/null +++ b/relay/storage/events.go @@ -0,0 +1,192 @@ +package storage + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "axon" +) + +// ErrDuplicate is returned by StoreEvent when the event ID already exists. +var ErrDuplicate = errors.New("storage: duplicate event") + +// StoreEvent persists an event and its tags to the database in a single +// transaction. envelopeBytes is the verbatim msgpack representation used for +// zero-copy fanout. +func (s *Storage) StoreEvent(ctx context.Context, event *axon.Event, envelopeBytes []byte) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("storage: begin tx: %w", err) + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, + `INSERT INTO events (id, pubkey, created_at, kind, envelope_bytes) VALUES (?, ?, ?, ?, ?)`, + event.ID, event.PubKey, event.CreatedAt, event.Kind, envelopeBytes, + ) + if err != nil { + if isDuplicateError(err) { + return ErrDuplicate + } + return fmt.Errorf("storage: insert event: %w", err) + } + + for _, tag := range event.Tags { + if len(tag.Values) == 0 { + continue + } + _, err = tx.ExecContext(ctx, + `INSERT INTO tags (event_id, name, value) VALUES (?, ?, ?)`, + event.ID, tag.Name, tag.Values[0], + ) + if err != nil { + return fmt.Errorf("storage: insert tag: %w", err) + } + } + + return tx.Commit() +} + +// ExistsByID returns true if an event with the given ID is already stored. +func (s *Storage) ExistsByID(ctx context.Context, id []byte) (bool, error) { + var n int + err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM events WHERE id = ?`, id).Scan(&n) + if err != nil && err != sql.ErrNoRows { + return false, fmt.Errorf("storage: exists: %w", err) + } + return n > 0, nil +} + +// QueryEvents executes the given filters against the database using a UNION +// query and returns matching event envelope bytes in descending created_at +// order. The effective LIMIT is the minimum non-zero Limit across all filters. +func (s *Storage) QueryEvents(ctx context.Context, filters []axon.Filter) ([][]byte, error) { + if len(filters) == 0 { + return nil, nil + } + + var unions []string + var args []interface{} + var effectiveLimit int32 + + for _, f := range filters { + var filterArgs []interface{} + clause := buildWhereClause(f, &filterArgs) + sub := fmt.Sprintf( + "SELECT e.envelope_bytes, e.created_at FROM events e WHERE %s", clause) + unions = append(unions, sub) + args = append(args, filterArgs...) + if f.Limit > 0 && (effectiveLimit == 0 || f.Limit < effectiveLimit) { + effectiveLimit = f.Limit + } + } + + query := strings.Join(unions, " UNION ") + " ORDER BY created_at DESC" + if effectiveLimit > 0 { + query += fmt.Sprintf(" LIMIT %d", effectiveLimit) + } + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("storage: query: %w", err) + } + defer rows.Close() + + var results [][]byte + for rows.Next() { + var envelope []byte + var createdAt int64 + if err := rows.Scan(&envelope, &createdAt); err != nil { + return nil, fmt.Errorf("storage: scan: %w", err) + } + results = append(results, envelope) + } + return results, rows.Err() +} + +// buildWhereClause builds the SQL WHERE clause for a single filter, appending +// bind parameters to args. +func buildWhereClause(f axon.Filter, args *[]interface{}) string { + var conditions []string + + if len(f.IDs) > 0 { + conditions = append(conditions, buildBlobPrefixCondition("e.id", f.IDs, args)) + } + + if len(f.Authors) > 0 { + conditions = append(conditions, buildBlobPrefixCondition("e.pubkey", f.Authors, args)) + } + + if len(f.Kinds) > 0 { + placeholders := make([]string, len(f.Kinds)) + for i, k := range f.Kinds { + placeholders[i] = "?" + *args = append(*args, k) + } + conditions = append(conditions, "e.kind IN ("+strings.Join(placeholders, ",")+")") + } + + if f.Since != 0 { + conditions = append(conditions, "e.created_at >= ?") + *args = append(*args, f.Since) + } + + if f.Until != 0 { + conditions = append(conditions, "e.created_at <= ?") + *args = append(*args, f.Until) + } + + for _, tf := range f.Tags { + conditions = append(conditions, buildTagJoinCondition(tf, args)) + } + + if len(conditions) == 0 { + return "1=1" + } + return strings.Join(conditions, " AND ") +} + +// buildBlobPrefixCondition builds an OR condition for prefix-matching a BLOB +// column. Prefix slices of exactly 32 bytes use equality; shorter slices use +// hex(column) LIKE 'HEX%'. +func buildBlobPrefixCondition(column string, prefixes [][]byte, args *[]interface{}) string { + var orConds []string + for _, prefix := range prefixes { + if len(prefix) == 32 { + orConds = append(orConds, column+" = ?") + *args = append(*args, prefix) + } else { + hexPrefix := fmt.Sprintf("%X", prefix) + orConds = append(orConds, fmt.Sprintf("hex(%s) LIKE ?", column)) + *args = append(*args, hexPrefix+"%") + } + } + if len(orConds) == 1 { + return orConds[0] + } + return "(" + strings.Join(orConds, " OR ") + ")" +} + +// buildTagJoinCondition builds an EXISTS sub-select for a TagFilter. +func buildTagJoinCondition(tf axon.TagFilter, args *[]interface{}) string { + if len(tf.Values) == 0 { + *args = append(*args, tf.Name) + return "EXISTS (SELECT 1 FROM tags t WHERE t.event_id = e.id AND t.name = ?)" + } + var orConds []string + for _, v := range tf.Values { + orConds = append(orConds, "EXISTS (SELECT 1 FROM tags t WHERE t.event_id = e.id AND t.name = ? AND t.value = ?)") + *args = append(*args, tf.Name, v) + } + if len(orConds) == 1 { + return orConds[0] + } + return "(" + strings.Join(orConds, " OR ") + ")" +} + +func isDuplicateError(err error) bool { + return err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") +} diff --git a/relay/storage/storage.go b/relay/storage/storage.go new file mode 100644 index 0000000..95b278d --- /dev/null +++ b/relay/storage/storage.go @@ -0,0 +1,86 @@ +// Package storage provides SQLite-backed event persistence for the Axon relay. +package storage + +import ( + "context" + "database/sql" + "fmt" + + _ "modernc.org/sqlite" +) + +// Storage wraps a SQLite database for Axon event persistence. +type Storage struct { + db *sql.DB +} + +// New opens (or creates) the SQLite database at dbPath, applies WAL pragmas, +// and initialises the schema. Call Close when done. +func New(dbPath string) (*Storage, error) { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, fmt.Errorf("storage: open db: %w", err) + } + + // SQLite works best with a single writer. + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(0) + + pragmas := []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA synchronous=NORMAL", + "PRAGMA cache_size=-40960", // ~40 MB (negative = kibibytes) + "PRAGMA temp_store=MEMORY", + "PRAGMA mmap_size=268435456", // 256 MB + "PRAGMA page_size=4096", + "PRAGMA foreign_keys=ON", + "PRAGMA busy_timeout=5000", + } + + for _, p := range pragmas { + if _, err := db.Exec(p); err != nil { + db.Close() + return nil, fmt.Errorf("storage: set pragma %q: %w", p, err) + } + } + + s := &Storage{db: db} + if err := s.initSchema(context.Background()); err != nil { + db.Close() + return nil, fmt.Errorf("storage: init schema: %w", err) + } + return s, nil +} + +// Close closes the underlying database connection. +func (s *Storage) Close() error { + return s.db.Close() +} + +const schema = ` +CREATE TABLE IF NOT EXISTS events ( + id BLOB PRIMARY KEY, + pubkey BLOB NOT NULL, + created_at INTEGER NOT NULL, + kind INTEGER NOT NULL, + envelope_bytes BLOB NOT NULL +) STRICT; + +CREATE TABLE IF NOT EXISTS tags ( + event_id BLOB NOT NULL REFERENCES events(id), + name TEXT NOT NULL, + value TEXT NOT NULL +) STRICT; + +CREATE INDEX IF NOT EXISTS idx_events_pubkey ON events(pubkey, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_events_kind ON events(kind, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_events_created_at ON events(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_tags_name_value ON tags(name, value); +CREATE INDEX IF NOT EXISTS idx_tags_event_id ON tags(event_id); +` + +func (s *Storage) initSchema(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, schema) + return err +} diff --git a/relay/subscription/manager.go b/relay/subscription/manager.go new file mode 100644 index 0000000..d8ba653 --- /dev/null +++ b/relay/subscription/manager.go @@ -0,0 +1,317 @@ +// Package subscription manages in-memory subscriptions and live event fanout. +package subscription + +import ( + "bytes" + "sync" + "time" + + "axon" +) + +// Subscription holds a single client subscription: an ID, a set of filters, +// and a buffered channel for live event delivery. +type Subscription struct { + ID string + Filters []axon.Filter + // Events carries raw msgpack envelope bytes for matched events. + Events chan []byte + done chan struct{} + once sync.Once +} + +// newSubscription creates a subscription with a buffered event channel. +func newSubscription(id string, filters []axon.Filter) *Subscription { + return &Subscription{ + ID: id, + Filters: filters, + Events: make(chan []byte, 100), + done: make(chan struct{}), + } +} + +// Close shuts down the subscription and drains the event channel. +func (s *Subscription) Close() { + s.once.Do(func() { + close(s.done) + close(s.Events) + }) +} + +// IsClosed reports whether the subscription has been closed. +func (s *Subscription) IsClosed() bool { + select { + case <-s.done: + return true + default: + return false + } +} + +// Done returns a channel that is closed when the subscription is closed. +func (s *Subscription) Done() <-chan struct{} { + return s.done +} + +// Manager maintains all active subscriptions for a single connection and fans +// out incoming events to matching subscribers. +type Manager struct { + mu sync.RWMutex + subs map[string]*Subscription +} + +// NewManager returns an empty Manager. +func NewManager() *Manager { + return &Manager{subs: make(map[string]*Subscription)} +} + +// Add registers a new subscription, replacing any existing subscription with +// the same ID. +func (m *Manager) Add(id string, filters []axon.Filter) *Subscription { + m.mu.Lock() + defer m.mu.Unlock() + + // Remove old subscription with same ID if present. + if old, ok := m.subs[id]; ok { + old.Close() + } + + sub := newSubscription(id, filters) + m.subs[id] = sub + return sub +} + +// Remove cancels and deletes a subscription by ID. +func (m *Manager) Remove(id string) { + m.mu.Lock() + defer m.mu.Unlock() + + if sub, ok := m.subs[id]; ok { + sub.Close() + delete(m.subs, id) + } +} + +// CloseAll cancels every subscription held by this manager. +func (m *Manager) CloseAll() { + m.mu.Lock() + defer m.mu.Unlock() + + for id, sub := range m.subs { + sub.Close() + delete(m.subs, id) + } +} + +// Fanout delivers envelopeBytes to all subscriptions whose filters match event. +// The send is non-blocking: if the channel is full the event is dropped for +// that subscriber. +func (m *Manager) Fanout(event *axon.Event, envelopeBytes []byte) { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, sub := range m.subs { + if sub.IsClosed() { + continue + } + if MatchesAnyFilter(event, sub.Filters) { + select { + case sub.Events <- envelopeBytes: + case <-sub.done: + default: + // channel full — drop + } + } + } +} + +// MatchesAnyFilter returns true if event matches at least one filter in filters. +func MatchesAnyFilter(event *axon.Event, filters []axon.Filter) bool { + for i := range filters { + if MatchesFilter(event, &filters[i]) { + return true + } + } + return false +} + +// MatchesFilter returns true if event satisfies all constraints in f. +func MatchesFilter(event *axon.Event, f *axon.Filter) bool { + if len(f.IDs) > 0 { + if !matchesBytesPrefixes(event.ID, f.IDs) { + return false + } + } + + if len(f.Authors) > 0 { + if !matchesBytesPrefixes(event.PubKey, f.Authors) { + return false + } + } + + if len(f.Kinds) > 0 { + found := false + for _, k := range f.Kinds { + if event.Kind == k { + found = true + break + } + } + if !found { + return false + } + } + + if f.Since != 0 && event.CreatedAt < f.Since { + return false + } + + if f.Until != 0 && event.CreatedAt > f.Until { + return false + } + + for _, tf := range f.Tags { + if !eventHasTagMatch(event, tf.Name, tf.Values) { + return false + } + } + + return true +} + +// matchesBytesPrefixes returns true if value has any of the given byte slices +// as a prefix. A prefix of exactly 32 bytes must match exactly. +func matchesBytesPrefixes(value []byte, prefixes [][]byte) bool { + for _, prefix := range prefixes { + if len(prefix) == 0 { + return true + } + if len(prefix) >= len(value) { + if bytes.Equal(value, prefix) { + return true + } + } else { + if bytes.HasPrefix(value, prefix) { + return true + } + } + } + return false +} + +// eventHasTagMatch returns true if event has a tag named name whose first +// value matches any of values. +func eventHasTagMatch(event *axon.Event, name string, values []string) bool { + for _, tag := range event.Tags { + if tag.Name != name { + continue + } + if len(values) == 0 { + return true + } + if len(tag.Values) == 0 { + continue + } + for _, v := range values { + if tag.Values[0] == v { + return true + } + } + } + return false +} + +// GlobalManager is a relay-wide manager that holds subscriptions from all +// connections and supports cross-connection fanout. +type GlobalManager struct { + mu sync.RWMutex + subs map[string]*Subscription // key: "connID:subID" +} + +// NewGlobalManager returns an empty GlobalManager. +func NewGlobalManager() *GlobalManager { + return &GlobalManager{subs: make(map[string]*Subscription)} +} + +// Register adds a subscription under a globally unique key. +func (g *GlobalManager) Register(connID, subID string, sub *Subscription) { + key := connID + ":" + subID + g.mu.Lock() + defer g.mu.Unlock() + if old, ok := g.subs[key]; ok { + old.Close() + } + g.subs[key] = sub +} + +// Unregister removes a subscription. +func (g *GlobalManager) Unregister(connID, subID string) { + key := connID + ":" + subID + g.mu.Lock() + defer g.mu.Unlock() + if sub, ok := g.subs[key]; ok { + sub.Close() + delete(g.subs, key) + } +} + +// UnregisterConn removes all subscriptions for a connection. +func (g *GlobalManager) UnregisterConn(connID string) { + prefix := connID + ":" + g.mu.Lock() + defer g.mu.Unlock() + for key, sub := range g.subs { + if len(key) > len(prefix) && key[:len(prefix)] == prefix { + sub.Close() + delete(g.subs, key) + } + } +} + +// Fanout delivers the event to all matching subscriptions across all connections. +func (g *GlobalManager) Fanout(event *axon.Event, envelopeBytes []byte) { + g.mu.RLock() + defer g.mu.RUnlock() + + for _, sub := range g.subs { + if sub.IsClosed() { + continue + } + if MatchesAnyFilter(event, sub.Filters) { + select { + case sub.Events <- envelopeBytes: + case <-sub.done: + default: + } + } + } +} + +// PurgeExpired removes closed subscriptions from the global map. +// Call periodically to prevent unbounded growth. +func (g *GlobalManager) PurgeExpired() { + g.mu.Lock() + defer g.mu.Unlock() + for key, sub := range g.subs { + if sub.IsClosed() { + delete(g.subs, key) + } + } +} + +// StartPurger launches a background goroutine that periodically removes closed +// subscriptions. +func (g *GlobalManager) StartPurger(interval time.Duration, stop <-chan struct{}) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + g.PurgeExpired() + case <-stop: + return + } + } + }() +} diff --git a/relay/websocket/websocket.go b/relay/websocket/websocket.go new file mode 100644 index 0000000..cfc3289 --- /dev/null +++ b/relay/websocket/websocket.go @@ -0,0 +1,244 @@ +// Package websocket implements RFC 6455 WebSocket framing without external dependencies. +// Adapted from muxstr's websocket implementation. +package websocket + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "time" +) + +const ( + opContinuation = 0x0 + opBinary = 0x2 + opClose = 0x8 + opPing = 0x9 + opPong = 0xA +) + +// Conn is a WebSocket connection. +type Conn struct { + rwc net.Conn + br *bufio.Reader + client bool + mu sync.Mutex +} + +func mask(key [4]byte, data []byte) { + for i := range data { + data[i] ^= key[i%4] + } +} + +func (c *Conn) writeFrame(opcode byte, payload []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + length := len(payload) + header := []byte{0x80 | opcode, 0} // FIN + opcode + + if c.client { + header[1] = 0x80 // mask bit + } + + switch { + case length <= 125: + header[1] |= byte(length) + case length <= 65535: + header[1] |= 126 + ext := make([]byte, 2) + binary.BigEndian.PutUint16(ext, uint16(length)) + header = append(header, ext...) + default: + header[1] |= 127 + ext := make([]byte, 8) + binary.BigEndian.PutUint64(ext, uint64(length)) + header = append(header, ext...) + } + + if c.client { + var key [4]byte + rand.Read(key[:]) + header = append(header, key[:]...) + // mask a copy so we don't modify the caller's slice + masked := make([]byte, len(payload)) + copy(masked, payload) + mask(key, masked) + payload = masked + } + + if _, err := c.rwc.Write(header); err != nil { + return err + } + _, err := c.rwc.Write(payload) + return err +} + +func (c *Conn) readFrame() (fin bool, opcode byte, payload []byte, err error) { + var hdr [2]byte + if _, err = io.ReadFull(c.br, hdr[:]); err != nil { + return + } + + fin = hdr[0]&0x80 != 0 + opcode = hdr[0] & 0x0F + masked := hdr[1]&0x80 != 0 + length := uint64(hdr[1] & 0x7F) + + switch length { + case 126: + var ext [2]byte + if _, err = io.ReadFull(c.br, ext[:]); err != nil { + return + } + length = uint64(binary.BigEndian.Uint16(ext[:])) + case 127: + var ext [8]byte + if _, err = io.ReadFull(c.br, ext[:]); err != nil { + return + } + length = binary.BigEndian.Uint64(ext[:]) + } + + var key [4]byte + if masked { + if _, err = io.ReadFull(c.br, key[:]); err != nil { + return + } + } + + payload = make([]byte, length) + if _, err = io.ReadFull(c.br, payload); err != nil { + return + } + + if masked { + mask(key, payload) + } + return +} + +// Read reads the next complete message from the connection. +// It handles ping frames automatically by sending pong responses. +// It respects context cancellation by setting a read deadline. +func (c *Conn) Read(ctx context.Context) ([]byte, error) { + stop := context.AfterFunc(ctx, func() { + c.rwc.SetReadDeadline(time.Now()) + }) + defer stop() + + var buf []byte + for { + fin, opcode, payload, err := c.readFrame() + if err != nil { + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, err + } + + switch opcode { + case opPing: + c.writeFrame(opPong, payload) + continue + case opClose: + return nil, fmt.Errorf("websocket: close frame received") + case opBinary, opContinuation: + buf = append(buf, payload...) + if fin { + return buf, nil + } + default: + // text or other opcodes — treat payload as binary + buf = append(buf, payload...) + if fin { + return buf, nil + } + } + } +} + +// Write sends a binary frame to the connection. +func (c *Conn) Write(data []byte) error { + return c.writeFrame(opBinary, data) +} + +// Ping sends a WebSocket ping frame. +func (c *Conn) Ping() error { + return c.writeFrame(opPing, nil) +} + +// Close sends a close frame with the given code and reason, then closes the +// underlying connection. +func (c *Conn) Close(code uint16, reason string) error { + payload := make([]byte, 2+len(reason)) + binary.BigEndian.PutUint16(payload, code) + copy(payload[2:], reason) + c.writeFrame(opClose, payload) + return c.rwc.Close() +} + +// CloseConn closes the underlying network connection without sending a close frame. +func (c *Conn) CloseConn() error { + return c.rwc.Close() +} + +// SetReadDeadline sets the read deadline on the underlying connection. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.rwc.SetReadDeadline(t) +} + +var wsGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + +func acceptKey(key string) string { + h := sha1.New() + h.Write([]byte(key)) + h.Write([]byte(wsGUID)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +// Accept performs the server-side WebSocket handshake, hijacking the HTTP +// connection and returning a Conn ready for framed I/O. +func Accept(w http.ResponseWriter, r *http.Request) (*Conn, error) { + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + return nil, fmt.Errorf("websocket: missing Upgrade header") + } + + key := r.Header.Get("Sec-WebSocket-Key") + if key == "" { + return nil, fmt.Errorf("websocket: missing Sec-WebSocket-Key") + } + + hj, ok := w.(http.Hijacker) + if !ok { + return nil, fmt.Errorf("websocket: response does not support hijacking") + } + + rwc, brw, err := hj.Hijack() + if err != nil { + return nil, err + } + + accept := acceptKey(key) + respStr := "HTTP/1.1 101 Switching Protocols\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Accept: " + accept + "\r\n\r\n" + + if _, err := rwc.Write([]byte(respStr)); err != nil { + rwc.Close() + return nil, err + } + + return &Conn{rwc: rwc, br: brw.Reader, client: false}, nil +} -- cgit v1.2.3