aboutsummaryrefslogtreecommitdiffstats
path: root/relay/server.go
blob: 9f716bfb313e941ef83b608ccc4e100f82403ee1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package main

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"log"
	"net/http"
	"sync"
	"time"

	"code.northwest.io/axon/relay/storage"
	"code.northwest.io/axon/relay/subscription"
	ws "code.northwest.io/axon/relay/websocket"
)

// Server is the HTTP + WebSocket server for the Axon relay.
//
// A Server is built with NewServer, run with Start and stopped with Shutdown.
// Every accepted WebSocket connection is handed the same store, global
// subscription manager, allowlist and relay URL.
type Server struct {
	// cfg supplies the listen address (used by Start) and the relay URL
	// (passed to each connection).
	cfg Config
	// allowlist is passed unchanged to every connection.
	// NOTE(review): presumably the set of keys allowed to use the relay — confirm.
	allowlist [][]byte
	// store is the storage handle shared by all connections.
	store *storage.Storage
	// global is the subscription manager shared by all connections; each
	// connection additionally gets its own subscription.Manager.
	global *subscription.GlobalManager

	// mu is a WaitGroup, not a mutex: it counts live per-connection handler
	// goroutines so Shutdown can wait for them to finish.
	mu sync.WaitGroup
	// httpSrv is created by Start and stopped by Shutdown.
	httpSrv *http.Server
}

// NewServer returns a Server wired to the given config, allowlist, storage and
// global subscription manager. The allowlist slice is retained as-is, not
// copied. Nothing listens until Start is called.
func NewServer(cfg Config, allowlist [][]byte, store *storage.Storage, global *subscription.GlobalManager) *Server {
	srv := new(Server)
	srv.cfg = cfg
	srv.allowlist = allowlist
	srv.store = store
	srv.global = global
	return srv
}

// Start configures the HTTP server and starts listening on cfg.Addr. It blocks
// until the server stops and returns the error from ListenAndServe (which is
// http.ErrServerClosed after a call to Shutdown). Call Shutdown to stop.
func (s *Server) Start() error {
	// readHeaderTimeout bounds how long a client may take to send its request
	// headers, so slow or stalled clients cannot hold connections open before
	// the WebSocket upgrade. It only covers the header phase: net/http resets
	// the read deadline once headers are read (ReadTimeout is left zero), so
	// long-lived upgraded connections are unaffected.
	const readHeaderTimeout = 10 * time.Second

	mux := http.NewServeMux()
	mux.HandleFunc("/", s.handleWS)

	s.httpSrv = &http.Server{
		Addr:              s.cfg.Addr,
		Handler:           mux,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	log.Printf("relay: listening on %s", s.cfg.Addr)
	return s.httpSrv.ListenAndServe()
}

// Shutdown gracefully stops the server and waits for all connections to drain.
//
// It first calls http.Server.Shutdown, which closes the listener and waits for
// in-flight ordinary requests. net/http does not track hijacked (WebSocket)
// connections, so Shutdown then waits for every per-connection goroutine
// started by handleWS to return, bounded by ctx.
//
// The returned error is the one from http.Server.Shutdown if non-nil.
// Otherwise it is ctx.Err() when ctx ends before the connection goroutines
// finish, so callers can tell the drain was incomplete. Otherwise it is nil.
//
// If Start has not yet configured the HTTP server there is nothing to stop,
// and Shutdown returns nil. NOTE(review): httpSrv is assigned by Start without
// synchronization, so calling Shutdown concurrently with Start is still a data
// race.
func (s *Server) Shutdown(ctx context.Context) error {
	if s.httpSrv == nil {
		return nil
	}
	err := s.httpSrv.Shutdown(ctx)

	// WaitGroup.Wait cannot be cancelled, so race it against ctx in a helper
	// goroutine. If ctx wins, the helper lingers until the remaining
	// connection goroutines exit.
	done := make(chan struct{})
	go func() {
		s.mu.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-ctx.Done():
		if err == nil {
			err = ctx.Err()
		}
	}
	return err
}

// handleWS upgrades an HTTP request to a WebSocket connection and starts the
// per-connection handler goroutine.
//
// The connection is counted in s.mu (the WaitGroup Shutdown waits on) before
// the upgrade, not after it. http.Server.Shutdown keeps waiting while a
// handler runs on a tracked connection but stops tracking a connection once it
// is hijacked. Registering first therefore raises the count before Shutdown
// can return and start waiting on s.mu, as sync.WaitGroup requires for an Add
// from zero. NOTE(review): this assumes ws.Accept hijacks the underlying
// connection — confirm against the websocket package.
//
// Each connection gets a fresh 32-byte auth nonce and a separate random
// connection ID (from generateConnID) used for logging. If the upgrade or
// either random read fails, the connection is abandoned.
func (s *Server) handleWS(w http.ResponseWriter, r *http.Request) {
	s.mu.Add(1)
	// Release the slot on every early return. Once the connection goroutine
	// is launched it owns the slot and releases it when it exits.
	handedOff := false
	defer func() {
		if !handedOff {
			s.mu.Done()
		}
	}()

	c, err := ws.Accept(w, r)
	if err != nil {
		// NOTE(review): if ws.Accept has already responded or hijacked before
		// failing, this write is redundant — confirm.
		http.Error(w, "WebSocket upgrade failed", http.StatusBadRequest)
		return
	}

	// Generate 32-byte nonce for the auth challenge.
	nonce := make([]byte, 32)
	if _, err := rand.Read(nonce); err != nil {
		log.Printf("relay: generate nonce: %v", err)
		_ = c.CloseConn() // best effort: the connection is being abandoned
		return
	}

	// Use a log identifier independent of the auth nonce rather than reusing
	// the nonce's leading bytes.
	connID, err := generateConnID()
	if err != nil {
		log.Printf("relay: %v", err)
		_ = c.CloseConn() // best effort: the connection is being abandoned
		return
	}

	h := &conn{
		id:        connID,
		ws:        c,
		store:     s.store,
		global:    s.global,
		allowlist: s.allowlist,
		relayURL:  s.cfg.RelayURL,
		nonce:     nonce,
		mgr:       subscription.NewManager(),
	}

	handedOff = true
	go func() {
		defer s.mu.Done()
		// r.Context() is cancelled when the handler (ServeHTTP) returns, which
		// happens right after this goroutine is launched, so use a fresh
		// context. The connection manages its own lifecycle via the ping loop
		// and WebSocket close frames.
		// NOTE(review): nothing cancels this context on Shutdown, so Shutdown
		// only drains connections whose clients disconnect before its ctx ends.
		ctx := context.Background()
		h.serve(ctx)
		// Close errors are ignored: the connection may already be gone.
		_ = c.CloseConn()
		log.Printf("conn %s: closed", connID)
	}()
}

// generateConnID returns a random 16-character lowercase hex string (8 bytes
// from crypto/rand) for identifying a connection in logs. It fails only if the
// system random source cannot be read.
func generateConnID() (id string, err error) {
	buf := make([]byte, 8)
	_, err = rand.Read(buf)
	if err != nil {
		return "", fmt.Errorf("server: generate conn id: %w", err)
	}
	id = hex.EncodeToString(buf)
	return id, nil
}