@ -260,10 +260,12 @@ type Conn struct {
newCompressionWriter func ( io . WriteCloser , int ) io . WriteCloser
// Read fields
reader io . ReadCloser // the current reader returned to the application
readErr error
br * bufio . Reader
readRemaining int64 // bytes remaining in current frame.
reader io . ReadCloser // the current reader returned to the application
readErr error
br * bufio . Reader
// bytes remaining in current frame.
// set setReadRemaining to safely update this value and prevent overflow
readRemaining int64
readFinal bool // true the current message has more frames.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
return c
}
// setReadRemaining tracks the number of bytes remaining on the connection. If n
// overflows, an ErrReadLimit is returned.
func ( c * Conn ) setReadRemaining ( n int64 ) error {
if n < 0 {
return ErrReadLimit
}
c . readRemaining = n
return nil
}
// Subprotocol returns the negotiated protocol for the connection.
func ( c * Conn ) Subprotocol ( ) string {
return c . subprotocol
@ -770,7 +783,7 @@ func (c *Conn) advanceFrame() (int, error) {
final := p [ 0 ] & finalBit != 0
frameType := int ( p [ 0 ] & 0xf )
mask := p [ 1 ] & maskBit != 0
c . readRemaining = int64 ( p [ 1 ] & 0x7f )
c . setReadRemaining ( int64 ( p [ 1 ] & 0x7f ) )
c . readDecompress = false
if c . newDecompressionReader != nil && ( p [ 0 ] & rsv1Bit ) != 0 {
@ -804,7 +817,17 @@ func (c *Conn) advanceFrame() (int, error) {
return noFrame , c . handleProtocolError ( "unknown opcode " + strconv . Itoa ( frameType ) )
}
// 3. Read and parse frame length.
// 3. Read and parse frame length as per
// https://tools.ietf.org/html/rfc6455#section-5.2
//
// The length of the "Payload data", in bytes: if 0-125, that is the payload
// length.
// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
// integer are the payload length.
// - If 127, the following 8 bytes interpreted as
// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
// payload length. Multibyte length quantities are expressed in network byte
// order.
switch c . readRemaining {
case 126 :
@ -812,13 +835,19 @@ func (c *Conn) advanceFrame() (int, error) {
if err != nil {
return noFrame , err
}
c . readRemaining = int64 ( binary . BigEndian . Uint16 ( p ) )
if err := c . setReadRemaining ( int64 ( binary . BigEndian . Uint16 ( p ) ) ) ; err != nil {
return noFrame , err
}
case 127 :
p , err := c . read ( 8 )
if err != nil {
return noFrame , err
}
c . readRemaining = int64 ( binary . BigEndian . Uint64 ( p ) )
if err := c . setReadRemaining ( int64 ( binary . BigEndian . Uint64 ( p ) ) ) ; err != nil {
return noFrame , err
}
}
// 4. Handle frame masking.
@ -841,6 +870,12 @@ func (c *Conn) advanceFrame() (int, error) {
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
c . readLength += c . readRemaining
// Don't allow readLength to overflow in the presence of a large readRemaining
// counter.
if c . readLength < 0 {
return noFrame , ErrReadLimit
}
if c . readLimit > 0 && c . readLength > c . readLimit {
c . WriteControl ( CloseMessage , FormatCloseMessage ( CloseMessageTooBig , "" ) , time . Now ( ) . Add ( writeWait ) )
return noFrame , ErrReadLimit
@ -854,7 +889,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload [ ] byte
if c . readRemaining > 0 {
payload , err = c . read ( int ( c . readRemaining ) )
c . readRemaining = 0
c . setReadRemaining ( 0 )
if err != nil {
return noFrame , err
}
@ -927,6 +962,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c . readErr = hideTempErr ( err )
break
}
if frameType == TextMessage || frameType == BinaryMessage {
c . messageReader = & messageReader { c }
c . reader = c . messageReader
@ -967,7 +1003,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c . isServer {
c . readMaskPos = maskBytes ( c . readMaskKey , c . readMaskPos , b [ : n ] )
}
c . readRemaining -= int64 ( n )
rem := c . readRemaining
rem -= int64 ( n )
c . setReadRemaining ( rem )
if c . readRemaining > 0 && c . readErr == io . EOF {
c . readErr = errUnexpectedEOF
}