package main import ( "context" "crypto/rand" "encoding/hex" "fmt" "log" "net/http" "sync" "axon/relay/storage" "axon/relay/subscription" ws "axon/relay/websocket" ) // Server is the HTTP + WebSocket server for the Axon relay. type Server struct { cfg Config allowlist [][]byte store *storage.Storage global *subscription.GlobalManager mu sync.WaitGroup httpSrv *http.Server } // NewServer creates a Server from the given config. func NewServer(cfg Config, allowlist [][]byte, store *storage.Storage, global *subscription.GlobalManager) *Server { return &Server{ cfg: cfg, allowlist: allowlist, store: store, global: global, } } // Start configures the HTTP server and starts listening. Call Shutdown to stop. func (s *Server) Start() error { mux := http.NewServeMux() mux.HandleFunc("/", s.handleWS) s.httpSrv = &http.Server{ Addr: s.cfg.Addr, Handler: mux, } log.Printf("relay: listening on %s", s.cfg.Addr) return s.httpSrv.ListenAndServe() } // Shutdown gracefully stops the server and waits for all connections to drain. func (s *Server) Shutdown(ctx context.Context) error { err := s.httpSrv.Shutdown(ctx) // Wait for all handler goroutines to finish. done := make(chan struct{}) go func() { s.mu.Wait() close(done) }() select { case <-done: case <-ctx.Done(): } return err } // handleWS upgrades an HTTP request to a WebSocket connection and starts the // per-connection handler goroutine. func (s *Server) handleWS(w http.ResponseWriter, r *http.Request) { c, err := ws.Accept(w, r) if err != nil { http.Error(w, "WebSocket upgrade failed", http.StatusBadRequest) return } // Generate 32-byte nonce for the auth challenge. nonce := make([]byte, 32) if _, err := rand.Read(nonce); err != nil { log.Printf("relay: generate nonce: %v", err) c.CloseConn() return } connID := hex.EncodeToString(nonce[:8]) h := &conn{ id: connID, ws: c, store: s.store, global: s.global, allowlist: s.allowlist, relayURL: s.cfg.RelayURL, nonce: nonce, mgr: subscription.NewManager(), } s.mu.Add(1) go func() { defer s.mu.Done() ctx := r.Context() h.serve(ctx) if err := c.CloseConn(); err != nil { // Ignore close errors — connection may already be gone. _ = err } log.Printf("conn %s: closed", connID) }() } // generateConnID creates a unique connection identifier for logging. func generateConnID() (string, error) { var b [8]byte if _, err := rand.Read(b[:]); err != nil { return "", fmt.Errorf("server: generate conn id: %w", err) } return hex.EncodeToString(b[:]), nil }