// Package websocket implements RFC 6455 WebSocket framing without external dependencies. // Adapted from muxstr's websocket implementation. package websocket import ( "bufio" "context" "crypto/rand" "crypto/sha1" "crypto/tls" "encoding/base64" "encoding/binary" "fmt" "io" "net" "net/http" "net/url" "strings" "sync" "time" ) const ( opContinuation = 0x0 opBinary = 0x2 opClose = 0x8 opPing = 0x9 opPong = 0xA ) // Conn is a WebSocket connection. type Conn struct { rwc net.Conn br *bufio.Reader client bool mu sync.Mutex } func mask(key [4]byte, data []byte) { for i := range data { data[i] ^= key[i%4] } } func (c *Conn) writeFrame(opcode byte, payload []byte) error { c.mu.Lock() defer c.mu.Unlock() length := len(payload) header := []byte{0x80 | opcode, 0} // FIN + opcode if c.client { header[1] = 0x80 // mask bit } switch { case length <= 125: header[1] |= byte(length) case length <= 65535: header[1] |= 126 ext := make([]byte, 2) binary.BigEndian.PutUint16(ext, uint16(length)) header = append(header, ext...) default: header[1] |= 127 ext := make([]byte, 8) binary.BigEndian.PutUint64(ext, uint64(length)) header = append(header, ext...) } if c.client { var key [4]byte rand.Read(key[:]) header = append(header, key[:]...) // mask a copy so we don't modify the caller's slice masked := make([]byte, len(payload)) copy(masked, payload) mask(key, masked) payload = masked } if _, err := c.rwc.Write(header); err != nil { return err } _, err := c.rwc.Write(payload) return err } func (c *Conn) readFrame() (fin bool, opcode byte, payload []byte, err error) { var hdr [2]byte if _, err = io.ReadFull(c.br, hdr[:]); err != nil { return } fin = hdr[0]&0x80 != 0 opcode = hdr[0] & 0x0F masked := hdr[1]&0x80 != 0 length := uint64(hdr[1] & 0x7F) switch length { case 126: var ext [2]byte if _, err = io.ReadFull(c.br, ext[:]); err != nil { return } length = uint64(binary.BigEndian.Uint16(ext[:])) case 127: var ext [8]byte if _, err = io.ReadFull(c.br, ext[:]); err != nil { return } length = binary.BigEndian.Uint64(ext[:]) } var key [4]byte if masked { if _, err = io.ReadFull(c.br, key[:]); err != nil { return } } payload = make([]byte, length) if _, err = io.ReadFull(c.br, payload); err != nil { return } if masked { mask(key, payload) } return } // Read reads the next complete message from the connection. // It handles ping frames automatically by sending pong responses. // It respects context cancellation by setting a read deadline. func (c *Conn) Read(ctx context.Context) ([]byte, error) { stop := context.AfterFunc(ctx, func() { c.rwc.SetReadDeadline(time.Now()) }) defer stop() var buf []byte for { fin, opcode, payload, err := c.readFrame() if err != nil { if ctx.Err() != nil { return nil, ctx.Err() } return nil, err } switch opcode { case opPing: c.writeFrame(opPong, payload) continue case opClose: return nil, fmt.Errorf("websocket: close frame received") case opBinary, opContinuation: buf = append(buf, payload...) if fin { return buf, nil } default: // text or other opcodes — treat payload as binary buf = append(buf, payload...) if fin { return buf, nil } } } } // Write sends a binary frame to the connection. func (c *Conn) Write(data []byte) error { return c.writeFrame(opBinary, data) } // Ping sends a WebSocket ping frame. func (c *Conn) Ping() error { return c.writeFrame(opPing, nil) } // Close sends a close frame with the given code and reason, then closes the // underlying connection. func (c *Conn) Close(code uint16, reason string) error { payload := make([]byte, 2+len(reason)) binary.BigEndian.PutUint16(payload, code) copy(payload[2:], reason) c.writeFrame(opClose, payload) return c.rwc.Close() } // CloseConn closes the underlying network connection without sending a close frame. func (c *Conn) CloseConn() error { return c.rwc.Close() } // SetReadDeadline sets the read deadline on the underlying connection. func (c *Conn) SetReadDeadline(t time.Time) error { return c.rwc.SetReadDeadline(t) } var wsGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" func acceptKey(key string) string { h := sha1.New() h.Write([]byte(key)) h.Write([]byte(wsGUID)) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } // Dial connects to a WebSocket server at rawURL and performs the client-side // RFC 6455 handshake. Supports ws:// and wss:// schemes. func Dial(rawURL string) (*Conn, error) { u, err := url.Parse(rawURL) if err != nil { return nil, fmt.Errorf("websocket: parse url: %w", err) } host := u.Host var netConn net.Conn switch u.Scheme { case "ws": if !strings.Contains(host, ":") { host += ":80" } netConn, err = net.Dial("tcp", host) case "wss": if !strings.Contains(host, ":") { host += ":443" } netConn, err = tls.Dial("tcp", host, &tls.Config{ServerName: u.Hostname()}) default: return nil, fmt.Errorf("websocket: unsupported scheme %q (use ws:// or wss://)", u.Scheme) } if err != nil { return nil, fmt.Errorf("websocket: dial %s: %w", host, err) } // Generate a random 16-byte key and base64-encode it. var keyBytes [16]byte if _, err := rand.Read(keyBytes[:]); err != nil { netConn.Close() return nil, fmt.Errorf("websocket: generate key: %w", err) } key := base64.StdEncoding.EncodeToString(keyBytes[:]) path := u.RequestURI() if path == "" { path = "/" } req := "GET " + path + " HTTP/1.1\r\n" + "Host: " + u.Host + "\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: " + key + "\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n" if _, err := netConn.Write([]byte(req)); err != nil { netConn.Close() return nil, fmt.Errorf("websocket: send handshake: %w", err) } br := bufio.NewReader(netConn) resp, err := http.ReadResponse(br, nil) if err != nil { netConn.Close() return nil, fmt.Errorf("websocket: read handshake response: %w", err) } resp.Body.Close() if resp.StatusCode != 101 { netConn.Close() return nil, fmt.Errorf("websocket: server returned status %d, want 101", resp.StatusCode) } if resp.Header.Get("Sec-WebSocket-Accept") != acceptKey(key) { netConn.Close() return nil, fmt.Errorf("websocket: bad Sec-WebSocket-Accept header") } return &Conn{rwc: netConn, br: br, client: true}, nil } // Accept performs the server-side WebSocket handshake, hijacking the HTTP // connection and returning a Conn ready for framed I/O. func Accept(w http.ResponseWriter, r *http.Request) (*Conn, error) { if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { return nil, fmt.Errorf("websocket: missing Upgrade header") } key := r.Header.Get("Sec-WebSocket-Key") if key == "" { return nil, fmt.Errorf("websocket: missing Sec-WebSocket-Key") } hj, ok := w.(http.Hijacker) if !ok { return nil, fmt.Errorf("websocket: response does not support hijacking") } rwc, brw, err := hj.Hijack() if err != nil { return nil, err } accept := acceptKey(key) respStr := "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + accept + "\r\n\r\n" if _, err := rwc.Write([]byte(respStr)); err != nil { rwc.Close() return nil, err } return &Conn{rwc: rwc, br: brw.Reader, client: false}, nil }