package websocket import ( "bufio" "context" "crypto/rand" "crypto/sha1" "crypto/tls" "encoding/base64" "encoding/binary" "fmt" "io" "net" "net/http" "net/url" "strings" "sync" "time" ) type MessageType int const MessageText MessageType = 1 type StatusCode int const StatusNormalClosure StatusCode = 1000 const ( opText = 0x1 opClose = 0x8 opPing = 0x9 opPong = 0xA ) type Conn struct { rwc net.Conn br *bufio.Reader client bool mu sync.Mutex } func mask(key [4]byte, data []byte) { for i := range data { data[i] ^= key[i%4] } } func (c *Conn) writeFrame(opcode byte, payload []byte) error { c.mu.Lock() defer c.mu.Unlock() length := len(payload) header := []byte{0x80 | opcode, 0} // FIN + opcode if c.client { header[1] = 0x80 // mask bit } switch { case length <= 125: header[1] |= byte(length) case length <= 65535: header[1] |= 126 ext := make([]byte, 2) binary.BigEndian.PutUint16(ext, uint16(length)) header = append(header, ext...) default: header[1] |= 127 ext := make([]byte, 8) binary.BigEndian.PutUint64(ext, uint64(length)) header = append(header, ext...) } if c.client { var key [4]byte rand.Read(key[:]) header = append(header, key[:]...) mask(key, payload) } if _, err := c.rwc.Write(header); err != nil { return err } _, err := c.rwc.Write(payload) return err } func (c *Conn) readFrame() (fin bool, opcode byte, payload []byte, err error) { var hdr [2]byte if _, err = io.ReadFull(c.br, hdr[:]); err != nil { return } fin = hdr[0]&0x80 != 0 opcode = hdr[0] & 0x0F masked := hdr[1]&0x80 != 0 length := uint64(hdr[1] & 0x7F) switch length { case 126: var ext [2]byte if _, err = io.ReadFull(c.br, ext[:]); err != nil { return } length = uint64(binary.BigEndian.Uint16(ext[:])) case 127: var ext [8]byte if _, err = io.ReadFull(c.br, ext[:]); err != nil { return } length = binary.BigEndian.Uint64(ext[:]) } var key [4]byte if masked { if _, err = io.ReadFull(c.br, key[:]); err != nil { return } } payload = make([]byte, length) if _, err = io.ReadFull(c.br, payload); err != nil { return } if masked { mask(key, payload) } return } func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { stop := context.AfterFunc(ctx, func() { c.rwc.SetReadDeadline(time.Now()) }) defer stop() var buf []byte for { fin, opcode, payload, err := c.readFrame() if err != nil { if ctx.Err() != nil { return 0, nil, ctx.Err() } return 0, nil, err } switch opcode { case opPing: c.writeFrame(opPong, payload) continue case opClose: return 0, nil, fmt.Errorf("websocket: close frame received") case opText, 0x0: // text or continuation buf = append(buf, payload...) if fin { return MessageText, buf, nil } default: buf = append(buf, payload...) if fin { return MessageText, buf, nil } } } } func (c *Conn) Write(ctx context.Context, typ MessageType, data []byte) error { return c.writeFrame(byte(typ), data) } func (c *Conn) Close(code StatusCode, reason string) error { payload := make([]byte, 2+len(reason)) binary.BigEndian.PutUint16(payload, uint16(code)) copy(payload[2:], reason) c.writeFrame(opClose, payload) return c.rwc.Close() } var wsGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" func acceptKey(key string) string { h := sha1.New() h.Write([]byte(key)) h.Write([]byte(wsGUID)) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } func Dial(ctx context.Context, rawURL string) (*Conn, error) { u, err := url.Parse(rawURL) if err != nil { return nil, err } host := u.Hostname() port := u.Port() useTLS := u.Scheme == "wss" if port == "" { if useTLS { port = "443" } else { port = "80" } } addr := net.JoinHostPort(host, port) rwc, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) if err != nil { return nil, err } if useTLS { tc := tls.Client(rwc, &tls.Config{ServerName: host}) if err := tc.HandshakeContext(ctx); err != nil { rwc.Close() return nil, err } rwc = tc } br := bufio.NewReader(rwc) var keyBytes [16]byte rand.Read(keyBytes[:]) key := base64.StdEncoding.EncodeToString(keyBytes[:]) path := u.RequestURI() reqStr := "GET " + path + " HTTP/1.1\r\n" + "Host: " + host + "\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: " + key + "\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n" if _, err := rwc.Write([]byte(reqStr)); err != nil { rwc.Close() return nil, err } req := &http.Request{Method: "GET"} resp, err := http.ReadResponse(br, req) if err != nil { rwc.Close() return nil, err } resp.Body.Close() if resp.StatusCode != 101 { rwc.Close() return nil, fmt.Errorf("websocket: bad handshake status %d", resp.StatusCode) } got := resp.Header.Get("Sec-WebSocket-Accept") want := acceptKey(key) if got != want { rwc.Close() return nil, fmt.Errorf("websocket: invalid Sec-WebSocket-Accept") } return &Conn{rwc: rwc, br: br, client: true}, nil } func Accept(w http.ResponseWriter, r *http.Request) (*Conn, error) { if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { return nil, fmt.Errorf("websocket: missing Upgrade header") } key := r.Header.Get("Sec-WebSocket-Key") if key == "" { return nil, fmt.Errorf("websocket: missing Sec-WebSocket-Key") } hj, ok := w.(http.Hijacker) if !ok { return nil, fmt.Errorf("websocket: response does not support hijacking") } rwc, brw, err := hj.Hijack() if err != nil { return nil, err } accept := acceptKey(key) respStr := "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + accept + "\r\n\r\n" if _, err := rwc.Write([]byte(respStr)); err != nil { rwc.Close() return nil, err } return &Conn{rwc: rwc, br: brw.Reader, client: false}, nil }