From 96dc2c1c8164b226cf0e2673454179b0b97cabb2 Mon Sep 17 00:00:00 2001 From: keepalivesrc <54171259+keepalivesrc@users.noreply.github.com> Date: Wed, 16 Oct 2019 01:14:01 -0700 Subject: [PATCH 1/3] websocket Read Limit Fix This fix addresses a potential denial-of-service (DoS) vector that can cause an integer overflow in the presence of malicious WebSocket frames. The fix adds additional checks against the remaining bytes on a connection, as well as a test to prevent regression. Credit to Max Justicz (https://justi.cz/) for discovering and reporting this, as well as providing a robust PoC and review. * bugfix: fix DoS vector caused by readLimit bypass * bugfix: payload length 127 should read bytes as uint64 * bugfix: defend against readLength overflows --- external/github.com/gorilla/websocket/conn.go | 58 +++++++++++++++---- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/external/github.com/gorilla/websocket/conn.go b/external/github.com/gorilla/websocket/conn.go index 55744a3e..25a14581 100644 --- a/external/github.com/gorilla/websocket/conn.go +++ b/external/github.com/gorilla/websocket/conn.go @@ -260,10 +260,12 @@ type Conn struct { newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields - reader io.ReadCloser // the current reader returned to the application - readErr error - br *bufio.Reader - readRemaining int64 // bytes remaining in current frame. + reader io.ReadCloser // the current reader returned to the application + readErr error + br *bufio.Reader + // bytes remaining in current frame. + // set setReadRemaining to safely update this value and prevent overflow + readRemaining int64 readFinal bool // true the current message has more frames. readLength int64 // Message size. readLimit int64 // Maximum message size. @@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, return c } +// setReadRemaining tracks the number of bytes remaining on the connection. If n +// overflows, an ErrReadLimit is returned. +func (c *Conn) setReadRemaining(n int64) error { + if n < 0 { + return ErrReadLimit + } + + c.readRemaining = n + return nil +} + // Subprotocol returns the negotiated protocol for the connection. func (c *Conn) Subprotocol() string { return c.subprotocol @@ -770,7 +783,7 @@ func (c *Conn) advanceFrame() (int, error) { final := p[0]&finalBit != 0 frameType := int(p[0] & 0xf) mask := p[1]&maskBit != 0 - c.readRemaining = int64(p[1] & 0x7f) + c.setReadRemaining(int64(p[1] & 0x7f)) c.readDecompress = false if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { @@ -804,7 +817,17 @@ func (c *Conn) advanceFrame() (int, error) { return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) } - // 3. Read and parse frame length. + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. switch c.readRemaining { case 126: @@ -812,13 +835,19 @@ func (c *Conn) advanceFrame() (int, error) { if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint16(p)) + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } case 127: p, err := c.read(8) if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint64(p)) + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } } // 4. Handle frame masking. @@ -841,6 +870,12 @@ func (c *Conn) advanceFrame() (int, error) { if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + if c.readLimit > 0 && c.readLength > c.readLimit { c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit @@ -854,7 +889,7 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - c.readRemaining = 0 + c.setReadRemaining(0) if err != nil { return noFrame, err } @@ -927,6 +962,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { c.readErr = hideTempErr(err) break } + if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} c.reader = c.messageReader @@ -967,7 +1003,9 @@ func (r *messageReader) Read(b []byte) (int, error) { if c.isServer { c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } - c.readRemaining -= int64(n) + rem := c.readRemaining + rem -= int64(n) + c.setReadRemaining(rem) if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } From 3b2d63d8d3623dac09f173d249f74c33c8d64de0 Mon Sep 17 00:00:00 2001 From: keepalivesrc <54171259+keepalivesrc@users.noreply.github.com> Date: Wed, 16 Oct 2019 16:52:41 -0700 Subject: [PATCH 2/3] update TestReadLimit sub-test --- .../github.com/gorilla/websocket/conn_test.go | 696 ++++++++++++++++++ 1 file changed, 696 insertions(+) create mode 100644 external/github.com/gorilla/websocket/conn_test.go diff --git a/external/github.com/gorilla/websocket/conn_test.go b/external/github.com/gorilla/websocket/conn_test.go new file mode 100644 index 00000000..63249bd3 --- /dev/null +++ b/external/github.com/gorilla/websocket/conn_test.go @@ -0,0 +1,696 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "reflect" + "sync" + "testing" + "testing/iotest" + "time" +) + +var _ net.Error = errWriteTimeout + +type fakeNetConn struct { + io.Reader + io.Writer +} + +func (c fakeNetConn) Close() error { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } +func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } +func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type fakeAddr int + +var ( + localAddr = fakeAddr(1) + remoteAddr = fakeAddr(2) +) + +func (a fakeAddr) Network() string { + return "net" +} + +func (a fakeAddr) String() string { + return "str" +} + +// newTestConn creates a connnection backed by a fake network connection using +// default values for buffering. +func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { + return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) +} + +func TestFraming(t *testing.T) { + frameSizes := []int{ + 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, + // 65536, 65537 + } + var readChunkers = []struct { + name string + f func(io.Reader) io.Reader + }{ + {"half", iotest.HalfReader}, + {"one", iotest.OneByteReader}, + {"asis", func(r io.Reader) io.Reader { return r }}, + } + writeBuf := make([]byte, 65537) + for i := range writeBuf { + writeBuf[i] = byte(i) + } + var writers = []struct { + name string + f func(w io.Writer, n int) (int, error) + }{ + {"iocopy", func(w io.Writer, n int) (int, error) { + nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) + return int(nn), err + }}, + {"write", func(w io.Writer, n int) (int, error) { + return w.Write(writeBuf[:n]) + }}, + {"string", func(w io.Writer, n int) (int, error) { + return io.WriteString(w, string(writeBuf[:n])) + }}, + } + + for _, compress := range []bool{false, true} { + for _, isServer := range []bool{true, false} { + for _, chunker := range readChunkers { + + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(chunker.f(&connBuf), nil, !isServer) + //if compress { + // wc.newCompressionWriter = compressNoContextTakeover + // rc.newDecompressionReader = decompressNoContextTakeover + //} + for _, n := range frameSizes { + for _, writer := range writers { + name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + nn, err := writer.f(w, n) + if err != nil || nn != n { + t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) + continue + } + err = w.Close() + if err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + + opCode, r, err := rc.NextReader() + if err != nil || opCode != TextMessage { + t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) + continue + } + + t.Logf("frame size: %d", n) + rbuf, err := ioutil.ReadAll(r) + if err != nil { + t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) + continue + } + + if len(rbuf) != n { + t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) + continue + } + + for i, b := range rbuf { + if byte(i) != b { + t.Errorf("%s: bad byte at offset %d", name, i) + break + } + } + } + } + } + } + } +} + +func TestControl(t *testing.T) { + const message = "this is a ping/pong messsage" + for _, isServer := range []bool{true, false} { + for _, isWriteControl := range []bool{true, false} { + name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(&connBuf, nil, !isServer) + if isWriteControl { + wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + } else { + w, err := wc.NextWriter(PongMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + if _, err := w.Write([]byte(message)); err != nil { + t.Errorf("%s: w.Write() returned %v", name, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + var actualMessage string + rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) + rc.NextReader() + if actualMessage != message { + t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) + continue + } + } + } + } +} + +// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. +type simpleBufferPool struct { + v interface{} +} + +func (p *simpleBufferPool) Get() interface{} { + v := p.v + p.v = nil + return v +} + +func (p *simpleBufferPool) Put(v interface{}) { + p.v = v +} + +func TestWriteBufferPool(t *testing.T) { + const message = "Now is the time for all good people to come to the aid of the party." + + var buf bytes.Buffer + var pool simpleBufferPool + rc := newTestConn(&buf, nil, false) + + // Specify writeBufferSize smaller than message size to ensure that pooling + // works with fragmented messages. + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after create") + } + + // Part 1: test NextWriter/Write/Close + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, message); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err != nil { + t.Fatalf("w.Close() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after w.Close()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + + // Part 2: Test WriteMessage. + + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after wc.WriteMessage()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool after WriteMessage") + } + + opCode, p, err = rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } +} + +// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. +func TestWriteBufferPoolSync(t *testing.T) { + var buf bytes.Buffer + var pool sync.Pool + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) + rc := newTestConn(&buf, nil, false) + + const message = "Hello World!" + for i := 0; i < 3; i++ { + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + } +} + +// errorWriter is an io.Writer than returns an error on all writes. +type errorWriter struct{} + +func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } + +// TestWriteBufferPoolError ensures that buffer is returned to pool after error +// on write. +func TestWriteBufferPoolError(t *testing.T) { + + // Part 1: Test NextWriter/Write/Close + + var pool simpleBufferPool + wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, "Hello"); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err == nil { + t.Fatalf("w.Close() did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + // Part 2: Test WriteMessage + + wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { + t.Fatalf("wc.WriteMessage did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } +} + +func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { + const bufSize = 512 + + expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize+bufSize/2)) + wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) + } + _, _, err = rc.NextReader() + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) + } +} + +func TestEOFWithinFrame(t *testing.T) { + const bufSize = 64 + + for n := 0; ; n++ { + var b bytes.Buffer + wc := newTestConn(nil, &b, false) + rc := newTestConn(&b, nil, true) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize)) + w.Close() + + if n >= b.Len() { + break + } + b.Truncate(n) + + op, r, err := rc.NextReader() + if err == errUnexpectedEOF { + continue + } + if op != BinaryMessage || err != nil { + t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) + } + } +} + +func TestEOFBeforeFinalFrame(t *testing.T) { + const bufSize = 512 + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize+bufSize/2)) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) + } +} + +func TestWriteAfterMessageWriterClose(t *testing.T) { + wc := newTestConn(nil, &bytes.Buffer{}, false) + w, _ := wc.NextWriter(BinaryMessage) + io.WriteString(w, "hello") + if err := w.Close(); err != nil { + t.Fatalf("unxpected error closing message writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } + + w, _ = wc.NextWriter(BinaryMessage) + io.WriteString(w, "hello") + + // close w by getting next writer + _, err := wc.NextWriter(BinaryMessage) + if err != nil { + t.Fatalf("unexpected error getting next writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } +} + +func TestReadLimit(t *testing.T) { + t.Run("Test ReadLimit is enforced", func(t *testing.T) { + const readLimit = 512 + message := make([]byte, readLimit+1) + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + w.Write(message[:readLimit-1]) + wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + w.Write(message[:1]) + w.Close() + + // Send message larger than the limit. + wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } + }) + + t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { + const readLimit = 1 + + var b1, b2 bytes.Buffer + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // First, send a non-final binary message + b1.Write([]byte("\x02\x81")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // First payload + b1.Write([]byte("A")) + + // Next, send a negative-length, non-final continuation frame + b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Next, send a too long, final continuation frame + b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Too-long payload + b1.Write([]byte("BCDEF")) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + + var buf [10]byte + var read int + n, err := r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + n, err = r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + if err == nil && read > readLimit { + t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) + } + }) +} + +func TestAddrs(t *testing.T) { + c := newTestConn(nil, nil, true) + if c.LocalAddr() != localAddr { + t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) + } + if c.RemoteAddr() != remoteAddr { + t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) + } +} + +func TestUnderlyingConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.UnderlyingConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + +func TestBufioReadBytes(t *testing.T) { + // Test calling bufio.ReadBytes for value longer than read buffer size. + + m := make([]byte, 512) + m[len(m)-1] = '\n' + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(m) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + + br := bufio.NewReader(r) + p, err := br.ReadBytes('\n') + if err != nil { + t.Fatalf("ReadBytes() returned %v", err) + } + if len(p) != len(m) { + t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) + } +} + +var closeErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestCloseError(t *testing.T) { + for _, tt := range closeErrorTests { + ok := IsCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +var unexpectedCloseErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestUnexpectedCloseErrors(t *testing.T) { + for _, tt := range unexpectedCloseErrorTests { + ok := IsUnexpectedCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +type blockingWriter struct { + c1, c2 chan struct{} +} + +func (w blockingWriter) Write(p []byte) (int, error) { + // Allow main to continue + close(w.c1) + // Wait for panic in main + <-w.c2 + return len(p), nil +} + +func TestConcurrentWritePanic(t *testing.T) { + w := blockingWriter{make(chan struct{}), make(chan struct{})} + c := newTestConn(nil, w, false) + go func() { + c.WriteMessage(TextMessage, []byte{}) + }() + + // wait for goroutine to block in write. + <-w.c1 + + defer func() { + close(w.c2) + if v := recover(); v != nil { + return + } + }() + + c.WriteMessage(TextMessage, []byte{}) + t.Fatal("should not get here") +} + +type failingReader struct{} + +func (r failingReader) Read(p []byte) (int, error) { + return 0, io.EOF +} + +func TestFailedConnectionReadPanic(t *testing.T) { + c := newTestConn(failingReader{}, nil, false) + + defer func() { + if v := recover(); v != nil { + return + } + }() + + for i := 0; i < 20000; i++ { + c.ReadMessage() + } + t.Fatal("should not get here") +} From 01c7bba529ab7e26645dee811f992b8f4f7f8e9d Mon Sep 17 00:00:00 2001 From: keepalivesrc <54171259+keepalivesrc@users.noreply.github.com> Date: Wed, 16 Oct 2019 16:58:07 -0700 Subject: [PATCH 3/3] Delete conn_test.go --- .../github.com/gorilla/websocket/conn_test.go | 696 ------------------ 1 file changed, 696 deletions(-) delete mode 100644 external/github.com/gorilla/websocket/conn_test.go diff --git a/external/github.com/gorilla/websocket/conn_test.go b/external/github.com/gorilla/websocket/conn_test.go deleted file mode 100644 index 63249bd3..00000000 --- a/external/github.com/gorilla/websocket/conn_test.go +++ /dev/null @@ -1,696 +0,0 @@ -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package websocket - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "io/ioutil" - "net" - "reflect" - "sync" - "testing" - "testing/iotest" - "time" -) - -var _ net.Error = errWriteTimeout - -type fakeNetConn struct { - io.Reader - io.Writer -} - -func (c fakeNetConn) Close() error { return nil } -func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } -func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } -func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } -func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } -func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } - -type fakeAddr int - -var ( - localAddr = fakeAddr(1) - remoteAddr = fakeAddr(2) -) - -func (a fakeAddr) Network() string { - return "net" -} - -func (a fakeAddr) String() string { - return "str" -} - -// newTestConn creates a connnection backed by a fake network connection using -// default values for buffering. -func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { - return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) -} - -func TestFraming(t *testing.T) { - frameSizes := []int{ - 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, - // 65536, 65537 - } - var readChunkers = []struct { - name string - f func(io.Reader) io.Reader - }{ - {"half", iotest.HalfReader}, - {"one", iotest.OneByteReader}, - {"asis", func(r io.Reader) io.Reader { return r }}, - } - writeBuf := make([]byte, 65537) - for i := range writeBuf { - writeBuf[i] = byte(i) - } - var writers = []struct { - name string - f func(w io.Writer, n int) (int, error) - }{ - {"iocopy", func(w io.Writer, n int) (int, error) { - nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) - return int(nn), err - }}, - {"write", func(w io.Writer, n int) (int, error) { - return w.Write(writeBuf[:n]) - }}, - {"string", func(w io.Writer, n int) (int, error) { - return io.WriteString(w, string(writeBuf[:n])) - }}, - } - - for _, compress := range []bool{false, true} { - for _, isServer := range []bool{true, false} { - for _, chunker := range readChunkers { - - var connBuf bytes.Buffer - wc := newTestConn(nil, &connBuf, isServer) - rc := newTestConn(chunker.f(&connBuf), nil, !isServer) - //if compress { - // wc.newCompressionWriter = compressNoContextTakeover - // rc.newDecompressionReader = decompressNoContextTakeover - //} - for _, n := range frameSizes { - for _, writer := range writers { - name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) - - w, err := wc.NextWriter(TextMessage) - if err != nil { - t.Errorf("%s: wc.NextWriter() returned %v", name, err) - continue - } - nn, err := writer.f(w, n) - if err != nil || nn != n { - t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) - continue - } - err = w.Close() - if err != nil { - t.Errorf("%s: w.Close() returned %v", name, err) - continue - } - - opCode, r, err := rc.NextReader() - if err != nil || opCode != TextMessage { - t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) - continue - } - - t.Logf("frame size: %d", n) - rbuf, err := ioutil.ReadAll(r) - if err != nil { - t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) - continue - } - - if len(rbuf) != n { - t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) - continue - } - - for i, b := range rbuf { - if byte(i) != b { - t.Errorf("%s: bad byte at offset %d", name, i) - break - } - } - } - } - } - } - } -} - -func TestControl(t *testing.T) { - const message = "this is a ping/pong messsage" - for _, isServer := range []bool{true, false} { - for _, isWriteControl := range []bool{true, false} { - name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) - var connBuf bytes.Buffer - wc := newTestConn(nil, &connBuf, isServer) - rc := newTestConn(&connBuf, nil, !isServer) - if isWriteControl { - wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) - } else { - w, err := wc.NextWriter(PongMessage) - if err != nil { - t.Errorf("%s: wc.NextWriter() returned %v", name, err) - continue - } - if _, err := w.Write([]byte(message)); err != nil { - t.Errorf("%s: w.Write() returned %v", name, err) - continue - } - if err := w.Close(); err != nil { - t.Errorf("%s: w.Close() returned %v", name, err) - continue - } - var actualMessage string - rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) - rc.NextReader() - if actualMessage != message { - t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) - continue - } - } - } - } -} - -// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. -type simpleBufferPool struct { - v interface{} -} - -func (p *simpleBufferPool) Get() interface{} { - v := p.v - p.v = nil - return v -} - -func (p *simpleBufferPool) Put(v interface{}) { - p.v = v -} - -func TestWriteBufferPool(t *testing.T) { - const message = "Now is the time for all good people to come to the aid of the party." - - var buf bytes.Buffer - var pool simpleBufferPool - rc := newTestConn(&buf, nil, false) - - // Specify writeBufferSize smaller than message size to ensure that pooling - // works with fragmented messages. - wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) - - if wc.writeBuf != nil { - t.Fatal("writeBuf not nil after create") - } - - // Part 1: test NextWriter/Write/Close - - w, err := wc.NextWriter(TextMessage) - if err != nil { - t.Fatalf("wc.NextWriter() returned %v", err) - } - - if wc.writeBuf == nil { - t.Fatal("writeBuf is nil after NextWriter") - } - - writeBufAddr := &wc.writeBuf[0] - - if _, err := io.WriteString(w, message); err != nil { - t.Fatalf("io.WriteString(w, message) returned %v", err) - } - - if err := w.Close(); err != nil { - t.Fatalf("w.Close() returned %v", err) - } - - if wc.writeBuf != nil { - t.Fatal("writeBuf not nil after w.Close()") - } - - if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { - t.Fatal("writeBuf not returned to pool") - } - - opCode, p, err := rc.ReadMessage() - if opCode != TextMessage || err != nil { - t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) - } - - if s := string(p); s != message { - t.Fatalf("message is %s, want %s", s, message) - } - - // Part 2: Test WriteMessage. - - if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { - t.Fatalf("wc.WriteMessage() returned %v", err) - } - - if wc.writeBuf != nil { - t.Fatal("writeBuf not nil after wc.WriteMessage()") - } - - if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { - t.Fatal("writeBuf not returned to pool after WriteMessage") - } - - opCode, p, err = rc.ReadMessage() - if opCode != TextMessage || err != nil { - t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) - } - - if s := string(p); s != message { - t.Fatalf("message is %s, want %s", s, message) - } -} - -// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. -func TestWriteBufferPoolSync(t *testing.T) { - var buf bytes.Buffer - var pool sync.Pool - wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) - rc := newTestConn(&buf, nil, false) - - const message = "Hello World!" - for i := 0; i < 3; i++ { - if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { - t.Fatalf("wc.WriteMessage() returned %v", err) - } - opCode, p, err := rc.ReadMessage() - if opCode != TextMessage || err != nil { - t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) - } - if s := string(p); s != message { - t.Fatalf("message is %s, want %s", s, message) - } - } -} - -// errorWriter is an io.Writer than returns an error on all writes. -type errorWriter struct{} - -func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } - -// TestWriteBufferPoolError ensures that buffer is returned to pool after error -// on write. -func TestWriteBufferPoolError(t *testing.T) { - - // Part 1: Test NextWriter/Write/Close - - var pool simpleBufferPool - wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) - - w, err := wc.NextWriter(TextMessage) - if err != nil { - t.Fatalf("wc.NextWriter() returned %v", err) - } - - if wc.writeBuf == nil { - t.Fatal("writeBuf is nil after NextWriter") - } - - writeBufAddr := &wc.writeBuf[0] - - if _, err := io.WriteString(w, "Hello"); err != nil { - t.Fatalf("io.WriteString(w, message) returned %v", err) - } - - if err := w.Close(); err == nil { - t.Fatalf("w.Close() did not return error") - } - - if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { - t.Fatal("writeBuf not returned to pool") - } - - // Part 2: Test WriteMessage - - wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) - - if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { - t.Fatalf("wc.WriteMessage did not return error") - } - - if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { - t.Fatal("writeBuf not returned to pool") - } -} - -func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { - const bufSize = 512 - - expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} - - var b1, b2 bytes.Buffer - wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) - rc := newTestConn(&b1, &b2, true) - - w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) - w.Close() - - op, r, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("NextReader() returned %d, %v", op, err) - } - _, err = io.Copy(ioutil.Discard, r) - if !reflect.DeepEqual(err, expectedErr) { - t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) - } - _, _, err = rc.NextReader() - if !reflect.DeepEqual(err, expectedErr) { - t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) - } -} - -func TestEOFWithinFrame(t *testing.T) { - const bufSize = 64 - - for n := 0; ; n++ { - var b bytes.Buffer - wc := newTestConn(nil, &b, false) - rc := newTestConn(&b, nil, true) - - w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize)) - w.Close() - - if n >= b.Len() { - break - } - b.Truncate(n) - - op, r, err := rc.NextReader() - if err == errUnexpectedEOF { - continue - } - if op != BinaryMessage || err != nil { - t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) - } - _, err = io.Copy(ioutil.Discard, r) - if err != errUnexpectedEOF { - t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) - } - _, _, err = rc.NextReader() - if err != errUnexpectedEOF { - t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) - } - } -} - -func TestEOFBeforeFinalFrame(t *testing.T) { - const bufSize = 512 - - var b1, b2 bytes.Buffer - wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) - rc := newTestConn(&b1, &b2, true) - - w, _ := wc.NextWriter(BinaryMessage) - w.Write(make([]byte, bufSize+bufSize/2)) - - op, r, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("NextReader() returned %d, %v", op, err) - } - _, err = io.Copy(ioutil.Discard, r) - if err != errUnexpectedEOF { - t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) - } - _, _, err = rc.NextReader() - if err != errUnexpectedEOF { - t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) - } -} - -func TestWriteAfterMessageWriterClose(t *testing.T) { - wc := newTestConn(nil, &bytes.Buffer{}, false) - w, _ := wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") - if err := w.Close(); err != nil { - t.Fatalf("unxpected error closing message writer, %v", err) - } - - if _, err := io.WriteString(w, "world"); err == nil { - t.Fatalf("no error writing after close") - } - - w, _ = wc.NextWriter(BinaryMessage) - io.WriteString(w, "hello") - - // close w by getting next writer - _, err := wc.NextWriter(BinaryMessage) - if err != nil { - t.Fatalf("unexpected error getting next writer, %v", err) - } - - if _, err := io.WriteString(w, "world"); err == nil { - t.Fatalf("no error writing after close") - } -} - -func TestReadLimit(t *testing.T) { - t.Run("Test ReadLimit is enforced", func(t *testing.T) { - const readLimit = 512 - message := make([]byte, readLimit+1) - - var b1, b2 bytes.Buffer - wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) - rc := newTestConn(&b1, &b2, true) - rc.SetReadLimit(readLimit) - - // Send message at the limit with interleaved pong. - w, _ := wc.NextWriter(BinaryMessage) - w.Write(message[:readLimit-1]) - wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) - w.Write(message[:1]) - w.Close() - - // Send message larger than the limit. - wc.WriteMessage(BinaryMessage, message[:readLimit+1]) - - op, _, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("1: NextReader() returned %d, %v", op, err) - } - op, r, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("2: NextReader() returned %d, %v", op, err) - } - _, err = io.Copy(ioutil.Discard, r) - if err != ErrReadLimit { - t.Fatalf("io.Copy() returned %v", err) - } - }) - - t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { - const readLimit = 1 - - var b1, b2 bytes.Buffer - rc := newTestConn(&b1, &b2, true) - rc.SetReadLimit(readLimit) - - // First, send a non-final binary message - b1.Write([]byte("\x02\x81")) - - // Mask key - b1.Write([]byte("\x00\x00\x00\x00")) - - // First payload - b1.Write([]byte("A")) - - // Next, send a negative-length, non-final continuation frame - b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) - - // Mask key - b1.Write([]byte("\x00\x00\x00\x00")) - - // Next, send a too long, final continuation frame - b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) - - // Mask key - b1.Write([]byte("\x00\x00\x00\x00")) - - // Too-long payload - b1.Write([]byte("BCDEF")) - - op, r, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("1: NextReader() returned %d, %v", op, err) - } - - var buf [10]byte - var read int - n, err := r.Read(buf[:]) - if err != nil && err != ErrReadLimit { - t.Fatalf("unexpected error testing read limit: %v", err) - } - read += n - - n, err = r.Read(buf[:]) - if err != nil && err != ErrReadLimit { - t.Fatalf("unexpected error testing read limit: %v", err) - } - read += n - - if err == nil && read > readLimit { - t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) - } - }) -} - -func TestAddrs(t *testing.T) { - c := newTestConn(nil, nil, true) - if c.LocalAddr() != localAddr { - t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) - } - if c.RemoteAddr() != remoteAddr { - t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) - } -} - -func TestUnderlyingConn(t *testing.T) { - var b1, b2 bytes.Buffer - fc := fakeNetConn{Reader: &b1, Writer: &b2} - c := newConn(fc, true, 1024, 1024, nil, nil, nil) - ul := c.UnderlyingConn() - if ul != fc { - t.Fatalf("Underlying conn is not what it should be.") - } -} - -func TestBufioReadBytes(t *testing.T) { - // Test calling bufio.ReadBytes for value longer than read buffer size. - - m := make([]byte, 512) - m[len(m)-1] = '\n' - - var b1, b2 bytes.Buffer - wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) - rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) - - w, _ := wc.NextWriter(BinaryMessage) - w.Write(m) - w.Close() - - op, r, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("NextReader() returned %d, %v", op, err) - } - - br := bufio.NewReader(r) - p, err := br.ReadBytes('\n') - if err != nil { - t.Fatalf("ReadBytes() returned %v", err) - } - if len(p) != len(m) { - t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) - } -} - -var closeErrorTests = []struct { - err error - codes []int - ok bool -}{ - {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, - {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, - {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, - {errors.New("hello"), []int{CloseNormalClosure}, false}, -} - -func TestCloseError(t *testing.T) { - for _, tt := range closeErrorTests { - ok := IsCloseError(tt.err, tt.codes...) - if ok != tt.ok { - t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) - } - } -} - -var unexpectedCloseErrorTests = []struct { - err error - codes []int - ok bool -}{ - {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, - {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, - {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, - {errors.New("hello"), []int{CloseNormalClosure}, false}, -} - -func TestUnexpectedCloseErrors(t *testing.T) { - for _, tt := range unexpectedCloseErrorTests { - ok := IsUnexpectedCloseError(tt.err, tt.codes...) - if ok != tt.ok { - t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) - } - } -} - -type blockingWriter struct { - c1, c2 chan struct{} -} - -func (w blockingWriter) Write(p []byte) (int, error) { - // Allow main to continue - close(w.c1) - // Wait for panic in main - <-w.c2 - return len(p), nil -} - -func TestConcurrentWritePanic(t *testing.T) { - w := blockingWriter{make(chan struct{}), make(chan struct{})} - c := newTestConn(nil, w, false) - go func() { - c.WriteMessage(TextMessage, []byte{}) - }() - - // wait for goroutine to block in write. - <-w.c1 - - defer func() { - close(w.c2) - if v := recover(); v != nil { - return - } - }() - - c.WriteMessage(TextMessage, []byte{}) - t.Fatal("should not get here") -} - -type failingReader struct{} - -func (r failingReader) Read(p []byte) (int, error) { - return 0, io.EOF -} - -func TestFailedConnectionReadPanic(t *testing.T) { - c := newTestConn(failingReader{}, nil, false) - - defer func() { - if v := recover(); v != nil { - return - } - }() - - for i := 0; i < 20000; i++ { - c.ReadMessage() - } - t.Fatal("should not get here") -}