diff --git a/vendor/github.com/lucas-clemente/quic-go/client.go b/vendor/github.com/lucas-clemente/quic-go/client.go index c5d66732..bf9a7959 100644 --- a/vendor/github.com/lucas-clemente/quic-go/client.go +++ b/vendor/github.com/lucas-clemente/quic-go/client.go @@ -27,7 +27,7 @@ type client struct { token []byte - versionNegotiated bool // has the server accepted our version + versionNegotiated utils.AtomicBool // has the server accepted our version receivedVersionNegotiationPacket bool negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet @@ -59,6 +59,7 @@ var ( ) // DialAddr establishes a new QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC session is closed. // The hostname for SNI is taken from the given address. func DialAddr( addr string, @@ -69,7 +70,7 @@ func DialAddr( } // DialAddrContext establishes a new QUIC connection to a server using the provided context. -// The hostname for SNI is taken from the given address. +// See DialAddr for details. func DialAddrContext( ctx context.Context, addr string, @@ -88,6 +89,8 @@ func DialAddrContext( } // Dial establishes a new QUIC connection to a server using a net.PacketConn. +// The same PacketConn can be used for multiple calls to Dial and Listen, +// QUIC connection IDs are used for demultiplexing the different connections. // The host parameter is used for SNI. func Dial( pconn net.PacketConn, @@ -100,7 +103,7 @@ func Dial( } // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. -// The host parameter is used for SNI. +// See Dial for details. func DialContext( ctx context.Context, pconn net.PacketConn, @@ -164,7 +167,18 @@ func newClient( } } } + + srcConnID, err := generateConnectionID(config.ConnectionIDLength) + if err != nil { + return nil, err + } + destConnID, err := generateConnectionIDForInitial() + if err != nil { + return nil, err + } c := &client{ + srcConnID: srcConnID, + destConnID: destConnID, conn: &conn{pconn: pconn, currentAddr: remoteAddr}, createdPacketConn: createdPacketConn, tlsConf: tlsConf, @@ -173,7 +187,7 @@ func newClient( handshakeChan: make(chan struct{}), logger: utils.DefaultLogger.WithPrefix("client"), } - return c, c.generateConnectionIDs() + return c, nil } // populateClientConfig populates fields in the quic.Config with their default values, if none are set @@ -234,20 +248,6 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config { } } -func (c *client) generateConnectionIDs() error { - srcConnID, err := generateConnectionID(c.config.ConnectionIDLength) - if err != nil { - return err - } - destConnID, err := generateConnectionIDForInitial() - if err != nil { - return err - } - c.srcConnID = srcConnID - c.destConnID = destConnID - return nil -} - func (c *client) dial(ctx context.Context) error { c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) @@ -292,65 +292,49 @@ func (c *client) establishSecureConnection(ctx context.Context) error { } func (c *client) handlePacket(p *receivedPacket) { - if err := c.handlePacketImpl(p); err != nil { - c.logger.Errorf("error handling packet: %s", err) - } -} - -func (c *client) handlePacketImpl(p *receivedPacket) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - // handle Version Negotiation Packets - if p.header.IsVersionNegotiation { - err := c.handleVersionNegotiationPacket(p.header) - if err != nil { - c.session.destroy(err) - } - // version negotiation packets have no payload - return err + if p.hdr.IsVersionNegotiation() { + go c.handleVersionNegotiationPacket(p.hdr) + return } - // reject packets with the wrong connection ID - if !p.header.DestConnectionID.Equal(c.srcConnID) { - return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID) - } - - if p.header.Type == protocol.PacketTypeRetry { - c.handleRetryPacket(p.header) - return nil + if p.hdr.Type == protocol.PacketTypeRetry { + go c.handleRetryPacket(p.hdr) + return } // this is the first packet we are receiving // since it is not a Version Negotiation Packet, this means the server supports the suggested version - if !c.versionNegotiated { - c.versionNegotiated = true + if !c.versionNegotiated.Get() { + c.versionNegotiated.Set(true) } c.session.handlePacket(p) - return nil } -func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { +func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) { + c.mutex.Lock() + defer c.mutex.Unlock() + // ignore delayed / duplicated version negotiation packets - if c.receivedVersionNegotiationPacket || c.versionNegotiated { - c.logger.Debugf("Received a delayed Version Negotiation Packet.") - return nil + if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() { + c.logger.Debugf("Received a delayed Version Negotiation packet.") + return } for _, v := range hdr.SupportedVersions { if v == c.version { - // the version negotiation packet contains the version that we offered - // this might be a packet sent by an attacker (or by a terribly broken server implementation) - // ignore it - return nil + // The Version Negotiation packet contains the version that we offered. + // This might be a packet sent by an attacker (or by a terribly broken server implementation). + return } } - c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions) + c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) if !ok { - return qerr.InvalidVersion + c.session.destroy(qerr.InvalidVersion) + c.logger.Debugf("No compatible version found.") + return } c.receivedVersionNegotiationPacket = true c.negotiatedVersions = hdr.SupportedVersions @@ -358,18 +342,17 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error { // switch to negotiated version c.initialVersion = c.version c.version = newVersion - if err := c.generateConnectionIDs(); err != nil { - return err - } c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) c.session.destroy(errCloseSessionForNewVersion) - return nil } func (c *client) handleRetryPacket(hdr *wire.Header) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.logger.Debugf("<- Received Retry") - hdr.Log(c.logger) + (&wire.ExtendedHeader{Header: *hdr}).Log(c.logger) if !hdr.OrigDestConnectionID.Equal(c.destConnID) { c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID) return diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go index 0e253354..f155bdd9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go @@ -75,12 +75,10 @@ type sentPacketHandler struct { alarm time.Time logger utils.Logger - - version protocol.VersionNumber } // NewSentPacketHandler creates a new sentPacketHandler -func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler { +func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler { congestion := congestion.NewCubicSender( congestion.DefaultClock{}, rttStats, @@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve rttStats: rttStats, congestion: congestion, logger: logger, - version: version, } } @@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) { func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) { pn := h.packetNumberGenerator.Peek() - return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version) + return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked()) } func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go index fb2f0bd4..2cffd92b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go @@ -11,11 +11,13 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config { c = &tls.Config{} } // QUIC requires TLS 1.3 or newer - if c.MinVersion < qtls.VersionTLS13 { - c.MinVersion = qtls.VersionTLS13 + minVersion := c.MinVersion + if minVersion < qtls.VersionTLS13 { + minVersion = qtls.VersionTLS13 } - if c.MaxVersion < qtls.VersionTLS13 { - c.MaxVersion = qtls.VersionTLS13 + maxVersion := c.MaxVersion + if maxVersion < qtls.VersionTLS13 { + maxVersion = qtls.VersionTLS13 } return &qtls.Config{ Rand: c.Rand, @@ -38,8 +40,8 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config { PreferServerCipherSuites: c.PreferServerCipherSuites, SessionTicketsDisabled: c.SessionTicketsDisabled, SessionTicketKey: c.SessionTicketKey, - MinVersion: c.MinVersion, - MaxVersion: c.MaxVersion, + MinVersion: minVersion, + MaxVersion: maxVersion, CurvePreferences: c.CurvePreferences, DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, Renegotiation: c.Renegotiation, diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go index e32d6baa..17f68055 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go @@ -1,20 +1,37 @@ package protocol +// PacketNumberLen is the length of the packet number in bytes +type PacketNumberLen uint8 + +const ( + // PacketNumberLenInvalid is the default value and not a valid length for a packet number + PacketNumberLenInvalid PacketNumberLen = 0 + // PacketNumberLen1 is a packet number length of 1 byte + PacketNumberLen1 PacketNumberLen = 1 + // PacketNumberLen2 is a packet number length of 2 bytes + PacketNumberLen2 PacketNumberLen = 2 + // PacketNumberLen3 is a packet number length of 3 bytes + PacketNumberLen3 PacketNumberLen = 3 + // PacketNumberLen4 is a packet number length of 4 bytes + PacketNumberLen4 PacketNumberLen = 4 +) + // InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number func InferPacketNumber( packetNumberLength PacketNumberLen, lastPacketNumber PacketNumber, wirePacketNumber PacketNumber, - version VersionNumber, ) PacketNumber { var epochDelta PacketNumber switch packetNumberLength { case PacketNumberLen1: - epochDelta = PacketNumber(1) << 7 + epochDelta = PacketNumber(1) << 8 case PacketNumberLen2: - epochDelta = PacketNumber(1) << 14 + epochDelta = PacketNumber(1) << 16 + case PacketNumberLen3: + epochDelta = PacketNumber(1) << 24 case PacketNumberLen4: - epochDelta = PacketNumber(1) << 30 + epochDelta = PacketNumber(1) << 32 } epoch := lastPacketNumber & ^(epochDelta - 1) prevEpochBegin := epoch - epochDelta @@ -42,11 +59,14 @@ func delta(a, b PacketNumber) PacketNumber { // GetPacketNumberLengthForHeader gets the length of the packet number for the public header // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances -func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen { +func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen { diff := uint64(packetNumber - leastUnacked) - if diff < (1 << (14 - 1)) { + if diff < (1 << (16 - 1)) { return PacketNumberLen2 } + if diff < (1 << (24 - 1)) { + return PacketNumberLen3 + } return PacketNumberLen4 } @@ -58,5 +78,8 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen { if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) { return PacketNumberLen2 } + if packetNumber < (1 << (uint8(PacketNumberLen3) * 8)) { + return PacketNumberLen3 + } return PacketNumberLen4 } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go index 77e1fb11..929708db 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go @@ -7,32 +7,18 @@ import ( // A PacketNumber in QUIC type PacketNumber uint64 -// PacketNumberLen is the length of the packet number in bytes -type PacketNumberLen uint8 - -const ( - // PacketNumberLenInvalid is the default value and not a valid length for a packet number - PacketNumberLenInvalid PacketNumberLen = 0 - // PacketNumberLen1 is a packet number length of 1 byte - PacketNumberLen1 PacketNumberLen = 1 - // PacketNumberLen2 is a packet number length of 2 bytes - PacketNumberLen2 PacketNumberLen = 2 - // PacketNumberLen4 is a packet number length of 4 bytes - PacketNumberLen4 PacketNumberLen = 4 -) - // The PacketType is the Long Header Type type PacketType uint8 const ( // PacketTypeInitial is the packet type of an Initial packet - PacketTypeInitial PacketType = 0x7f + PacketTypeInitial PacketType = 1 + iota // PacketTypeRetry is the packet type of a Retry packet - PacketTypeRetry PacketType = 0x7e + PacketTypeRetry // PacketTypeHandshake is the packet type of a Handshake packet - PacketTypeHandshake PacketType = 0x7d + PacketTypeHandshake // PacketType0RTT is the packet type of a 0-RTT packet - PacketType0RTT PacketType = 0x7c + PacketType0RTT ) func (t PacketType) String() string { @@ -72,11 +58,5 @@ const DefaultTCPMSS ByteCount = 1460 // MinInitialPacketSize is the minimum size an Initial packet is required to have. const MinInitialPacketSize = 1200 -// MaxClientHellos is the maximum number of times we'll send a client hello -// The value 3 accounts for: -// * one failure due to an incorrect or missing source-address token -// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token -const MaxClientHellos = 3 - // MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. const MinConnectionIDLenInitial = 8 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go index b4a44517..6b92cfa2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go @@ -8,11 +8,10 @@ import ( // A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. type ByteOrder interface { ReadUintN(b io.ByteReader, length uint8) (uint64, error) - ReadUint64(io.ByteReader) (uint64, error) ReadUint32(io.ByteReader) (uint32, error) ReadUint16(io.ByteReader) (uint16, error) - WriteUint64(*bytes.Buffer, uint64) + WriteUintN(b *bytes.Buffer, length uint8, value uint64) WriteUint32(*bytes.Buffer, uint32) WriteUint16(*bytes.Buffer, uint16) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go index 8ee6e1ab..eede9cd7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go @@ -25,37 +25,6 @@ func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { return res, nil } -// ReadUint64 reads a uint64 -func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) { - var b1, b2, b3, b4, b5, b6, b7, b8 uint8 - var err error - if b8, err = b.ReadByte(); err != nil { - return 0, err - } - if b7, err = b.ReadByte(); err != nil { - return 0, err - } - if b6, err = b.ReadByte(); err != nil { - return 0, err - } - if b5, err = b.ReadByte(); err != nil { - return 0, err - } - if b4, err = b.ReadByte(); err != nil { - return 0, err - } - if b3, err = b.ReadByte(); err != nil { - return 0, err - } - if b2, err = b.ReadByte(); err != nil { - return 0, err - } - if b1, err = b.ReadByte(); err != nil { - return 0, err - } - return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil -} - // ReadUint32 reads a uint32 func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { var b1, b2, b3, b4 uint8 @@ -88,12 +57,10 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { return uint16(b1) + uint16(b2)<<8, nil } -// WriteUint64 writes a uint64 -func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) { - b.Write([]byte{ - uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), - uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), - }) +func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) { + for j := length; j > 0; j-- { + b.WriteByte(uint8(i >> (8 * (j - 1)))) + } } // WriteUint32 writes a uint32 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go new file mode 100644 index 00000000..a08dd487 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go @@ -0,0 +1,205 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// ExtendedHeader is the header of a QUIC packet. +type ExtendedHeader struct { + Header + + typeByte byte + Raw []byte + + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + + KeyPhase int +} + +func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { + // read the (now unencrypted) first byte + var err error + h.typeByte, err = b.ReadByte() + if err != nil { + return nil, err + } + if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil { + return nil, err + } + if h.IsLongHeader { + return h.parseLongHeader(b, v) + } + return h.parseShortHeader(b, v) +} + +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { + if h.typeByte&0xc != 0 { + return nil, errors.New("5th and 6th bit must be 0") + } + if err := h.readPacketNumber(b); err != nil { + return nil, err + } + return h, nil +} + +func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) { + if h.typeByte&0x18 != 0 { + return nil, errors.New("4th and 5th bit must be 0") + } + + h.KeyPhase = int(h.typeByte&0x4) >> 2 + + if err := h.readPacketNumber(b); err != nil { + return nil, err + } + return h, nil +} + +func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { + h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 + pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen)) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(pn) + return nil +} + +// Write writes the Header. +func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { + if h.IsLongHeader { + return h.writeLongHeader(b, ver) + } + return h.writeShortHeader(b, ver) +} + +func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error { + var packetType uint8 + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0x0 + case protocol.PacketType0RTT: + packetType = 0x1 + case protocol.PacketTypeHandshake: + packetType = 0x2 + case protocol.PacketTypeRetry: + packetType = 0x3 + } + firstByte := 0xc0 | packetType<<4 + if h.Type == protocol.PacketTypeRetry { + odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID) + if err != nil { + return err + } + firstByte |= odcil + } else { // Retry packets don't have a packet number + firstByte |= uint8(h.PacketNumberLen - 1) + } + + b.WriteByte(firstByte) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) + connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) + if err != nil { + return err + } + b.WriteByte(connIDLen) + b.Write(h.DestConnectionID.Bytes()) + b.Write(h.SrcConnectionID.Bytes()) + + switch h.Type { + case protocol.PacketTypeRetry: + b.Write(h.OrigDestConnectionID.Bytes()) + b.Write(h.Token) + return nil + case protocol.PacketTypeInitial: + utils.WriteVarInt(b, uint64(len(h.Token))) + b.Write(h.Token) + } + + utils.WriteVarInt(b, uint64(h.Length)) + return h.writePacketNumber(b) +} + +// TODO: add support for the key phase +func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error { + typeByte := 0x40 | uint8(h.PacketNumberLen-1) + typeByte |= byte(h.KeyPhase << 2) + + b.WriteByte(typeByte) + b.Write(h.DestConnectionID.Bytes()) + return h.writePacketNumber(b) +} + +func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { + if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 { + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber)) + return nil +} + +// GetLength determines the length of the Header. +func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { + if h.IsLongHeader { + length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length)) + if h.Type == protocol.PacketTypeInitial { + length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) + } + return length + } + + length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) + length += protocol.ByteCount(h.PacketNumberLen) + return length +} + +// Log logs the Header +func (h *ExtendedHeader) Log(logger utils.Logger) { + if h.IsLongHeader { + var token string + if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { + if len(h.Token) == 0 { + token = "Token: (empty), " + } else { + token = fmt.Sprintf("Token: %#x, ", h.Token) + } + if h.Type == protocol.PacketTypeRetry { + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version) + return + } + } + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) + } else { + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + } +} + +func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) { + dcil, err := encodeSingleConnIDLen(dest) + if err != nil { + return 0, err + } + scil, err := encodeSingleConnIDLen(src) + if err != nil { + return 0, err + } + return scil | dcil<<4, nil +} + +func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) { + len := id.Len() + if len == 0 { + return 0, nil + } + if len < 4 || len > 18 { + return 0, fmt.Errorf("invalid connection ID length: %d bytes", len) + } + return byte(len - 3), nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go index 36255fdd..c40d40b2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go @@ -2,150 +2,183 @@ package wire import ( "bytes" - "crypto/rand" - "fmt" + "errors" + "io" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" ) -// Header is the header of a QUIC packet. +// The Header is the version independent part of the header type Header struct { - Raw []byte + Version protocol.VersionNumber + SrcConnectionID protocol.ConnectionID + DestConnectionID protocol.ConnectionID - Version protocol.VersionNumber - - DestConnectionID protocol.ConnectionID - SrcConnectionID protocol.ConnectionID - OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet - - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - - IsVersionNegotiation bool - SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server - - Type protocol.PacketType IsLongHeader bool - KeyPhase int + Type protocol.PacketType Length protocol.ByteCount - Token []byte + + Token []byte + SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet + OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet + + typeByte byte + len int // how many bytes were read while parsing this header } -// Write writes the Header. -func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error { - if h.IsLongHeader { - return h.writeLongHeader(b, ver) +// ParseHeader parses the header. +// For short header packets: up to the packet number. +// For long header packets: +// * if we understand the version: up to the packet number +// * if not, only the invariant part of the header +func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { + startLen := b.Len() + h, err := parseHeaderImpl(b, shortHeaderConnIDLen) + if err != nil { + return nil, err } - return h.writeShortHeader(b, ver) + h.len = startLen - b.Len() + return h, nil } -// TODO: add support for the key phase -func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error { - b.WriteByte(byte(0x80 | h.Type)) - utils.BigEndian.WriteUint32(b, uint32(h.Version)) - connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID) +func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + + h := &Header{ + typeByte: typeByte, + IsLongHeader: typeByte&0x80 > 0, + } + + if !h.IsLongHeader { + if h.typeByte&0x40 == 0 { + return nil, errors.New("not a QUIC packet") + } + if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil { + return nil, err + } + return h, nil + } + if err := h.parseLongHeader(b); err != nil { + return nil, err + } + return h, nil +} + +func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error { + var err error + h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen) + return err +} + +func (h *Header) parseLongHeader(b *bytes.Reader) error { + v, err := utils.BigEndian.ReadUint32(b) if err != nil { return err } - b.WriteByte(connIDLen) - b.Write(h.DestConnectionID.Bytes()) - b.Write(h.SrcConnectionID.Bytes()) - - if h.Type == protocol.PacketTypeInitial { - utils.WriteVarInt(b, uint64(len(h.Token))) - b.Write(h.Token) + h.Version = protocol.VersionNumber(v) + if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 { + return errors.New("not a QUIC packet") } - - if h.Type == protocol.PacketTypeRetry { - odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID) - if err != nil { - return err - } - // randomize the first 4 bits - odcilByte := make([]byte, 1) - _, _ = rand.Read(odcilByte) // it's safe to ignore the error here - odcilByte[0] = (odcilByte[0] & 0xf0) | odcil - b.Write(odcilByte) - b.Write(h.OrigDestConnectionID.Bytes()) - b.Write(h.Token) + connIDLenByte, err := b.ReadByte() + if err != nil { + return err + } + dcil, scil := decodeConnIDLen(connIDLenByte) + h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil) + if err != nil { + return err + } + h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil) + if err != nil { + return err + } + if h.Version == 0 { + return h.parseVersionNegotiationPacket(b) + } + // If we don't understand the version, we have no idea how to interpret the rest of the bytes + if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { return nil } - utils.WriteVarInt(b, uint64(h.Length)) - return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen) -} + switch (h.typeByte & 0x30) >> 4 { + case 0x0: + h.Type = protocol.PacketTypeInitial + case 0x1: + h.Type = protocol.PacketType0RTT + case 0x2: + h.Type = protocol.PacketTypeHandshake + case 0x3: + h.Type = protocol.PacketTypeRetry + } -func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error { - typeByte := byte(0x30) - typeByte |= byte(h.KeyPhase << 6) - - b.WriteByte(typeByte) - b.Write(h.DestConnectionID.Bytes()) - return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen) -} - -// GetLength determines the length of the Header. -func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount { - if h.IsLongHeader { - length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length)) - if h.Type == protocol.PacketTypeInitial { - length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) + if h.Type == protocol.PacketTypeRetry { + odcil := decodeSingleConnIDLen(h.typeByte & 0xf) + h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil) + if err != nil { + return err } - return length - } - - length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) - length += protocol.ByteCount(h.PacketNumberLen) - return length -} - -// Log logs the Header -func (h *Header) Log(logger utils.Logger) { - if h.IsLongHeader { - if h.Version == 0 { - logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions) - } else { - var token string - if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { - if len(h.Token) == 0 { - token = "Token: (empty), " - } else { - token = fmt.Sprintf("Token: %#x, ", h.Token) - } - } - if h.Type == protocol.PacketTypeRetry { - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version) - return - } - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) + h.Token = make([]byte, b.Len()) + if _, err := io.ReadFull(b, h.Token); err != nil { + return err } - } else { - logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + return nil } + + if h.Type == protocol.PacketTypeInitial { + tokenLen, err := utils.ReadVarInt(b) + if err != nil { + return err + } + if tokenLen > uint64(b.Len()) { + return io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return err + } + } + + pl, err := utils.ReadVarInt(b) + if err != nil { + return err + } + h.Length = protocol.ByteCount(pl) + return nil } -func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) { - dcil, err := encodeSingleConnIDLen(dest) - if err != nil { - return 0, err +func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error { + if b.Len() == 0 { + return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list") } - scil, err := encodeSingleConnIDLen(src) - if err != nil { - return 0, err + h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4) + for i := 0; b.Len() > 0; i++ { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return qerr.InvalidVersionNegotiationPacket + } + h.SupportedVersions[i] = protocol.VersionNumber(v) } - return scil | dcil<<4, nil + return nil } -func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) { - len := id.Len() - if len == 0 { - return 0, nil - } - if len < 4 || len > 18 { - return 0, fmt.Errorf("invalid connection ID length: %d bytes", len) - } - return byte(len - 3), nil +// IsVersionNegotiation says if this a version negotiation packet +func (h *Header) IsVersionNegotiation() bool { + return h.IsLongHeader && h.Version == 0 +} + +// ParseExtended parses the version dependent part of the header. +// The Reader has to be set such that it points to the first byte of the header. +func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { + return h.toExtendedHeader().parse(b, ver) +} + +func (h *Header) toExtendedHeader() *ExtendedHeader { + return &ExtendedHeader{Header: *h} } func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) { diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go b/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go index a4acb6fd..7e411df0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go @@ -162,75 +162,46 @@ func (h *packetHandlerMap) listen() { } func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error { - rcvTime := time.Now() - r := bytes.NewReader(data) - iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen) + hdr, err := wire.ParseHeader(r, h.connIDLen) // drop the packet if we can't parse the header - if err != nil { - return fmt.Errorf("error parsing invariant header: %s", err) - } - - h.mutex.RLock() - handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)] - server := h.server - - var sentBy protocol.Perspective - var version protocol.VersionNumber - var handlePacket func(*receivedPacket) - if handlerFound { // existing session - handler := handlerEntry.handler - sentBy = handler.GetPerspective().Opposite() - version = handler.GetVersion() - handlePacket = handler.handlePacket - } else { // no session found - // this might be a stateless reset - if !iHdr.IsLongHeader { - if len(data) >= protocol.MinStatelessResetSize { - var token [16]byte - copy(token[:], data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - h.mutex.RUnlock() - sess.destroy(errors.New("received a stateless reset")) - return nil - } - } - // TODO(#943): send a stateless reset - return fmt.Errorf("received a short header packet with an unexpected connection ID %s", iHdr.DestConnectionID) - } - if server == nil { // no server set - h.mutex.RUnlock() - return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) - } - handlePacket = server.handlePacket - sentBy = protocol.PerspectiveClient - version = iHdr.Version - } - h.mutex.RUnlock() - - hdr, err := iHdr.Parse(r, sentBy, version) if err != nil { return fmt.Errorf("error parsing header: %s", err) } - hdr.Raw = data[:len(data)-r.Len()] - packetData := data[len(data)-r.Len():] - if hdr.IsLongHeader { - if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) { - return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen) - } - if protocol.ByteCount(len(packetData))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length { - return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(packetData)+int(hdr.PacketNumberLen), hdr.Length) - } - packetData = packetData[:int(hdr.Length)-int(hdr.PacketNumberLen)] - // TODO(#1312): implement parsing of compound packets + p := &receivedPacket{ + remoteAddr: addr, + hdr: hdr, + data: data, + rcvTime: time.Now(), } - handlePacket(&receivedPacket{ - remoteAddr: addr, - header: hdr, - data: packetData, - rcvTime: rcvTime, - }) + h.mutex.RLock() + defer h.mutex.RUnlock() + + handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)] + + if handlerFound { // existing session + handlerEntry.handler.handlePacket(p) + return nil + } + // No session found. + // This might be a stateless reset. + if !hdr.IsLongHeader { + if len(data) >= protocol.MinStatelessResetSize { + var token [16]byte + copy(token[:], data[len(data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + sess.destroy(errors.New("received a stateless reset")) + return nil + } + } + // TODO(#943): send a stateless reset + return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID) + } + if h.server == nil { // no server set + return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID) + } + h.server.handlePacket(p) return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go index e1510341..a87fcc80 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go @@ -25,7 +25,7 @@ type packer interface { } type packedPacket struct { - header *wire.Header + header *wire.ExtendedHeader raw []byte frames []wire.Frame encryptionLevel protocol.EncryptionLevel @@ -397,14 +397,13 @@ func (p *packetPacker) composeNextPacket( return frames, nil } -func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { +func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber() - header := &wire.Header{ - PacketNumber: pn, - PacketNumberLen: pnLen, - Version: p.version, - DestConnectionID: p.destConnID, - } + header := &wire.ExtendedHeader{} + header.PacketNumber = pn + header.PacketNumberLen = pnLen + header.Version = p.version + header.DestConnectionID = p.destConnID if encLevel != protocol.Encryption1RTT { header.IsLongHeader = true @@ -424,8 +423,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header } func (p *packetPacker) writeAndSealPacket( - header *wire.Header, - frames []wire.Frame, + header *wire.ExtendedHeader, frames []wire.Frame, sealer handshake.Sealer, ) ([]byte, error) { raw := *getPacketBuffer() @@ -450,7 +448,7 @@ func (p *packetPacker) writeAndSealPacket( } } - if err := header.Write(buffer, p.perspective, p.version); err != nil { + if err := header.Write(buffer, p.version); err != nil { return nil, err } payloadStartIndex := buffer.Len() diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go index f073395f..52aa4759 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go +++ b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go @@ -35,7 +35,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker { } } -func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) { +func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) { buf := *getPacketBuffer() buf = buf[:0] defer putPacketBuffer(&buf) diff --git a/vendor/github.com/lucas-clemente/quic-go/server.go b/vendor/github.com/lucas-clemente/quic-go/server.go index 5f862c5b..33a3123c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server.go +++ b/vendor/github.com/lucas-clemente/quic-go/server.go @@ -21,7 +21,6 @@ type packetHandler interface { handlePacket(*receivedPacket) io.Closer destroy(error) - GetVersion() protocol.VersionNumber GetPerspective() protocol.Perspective } @@ -99,7 +98,8 @@ var _ Listener = &server{} var _ unknownPacketHandler = &server{} // ListenAddr creates a QUIC server listening on a given address. -// The tls.Config must not be nil, the quic.Config may be nil. +// The tls.Config must not be nil and must contain a certificate configuration. +// The quic.Config may be nil, in that case the default values will be used. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { @@ -118,7 +118,11 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err } // Listen listens for QUIC connections on a given net.PacketConn. -// The tls.Config must not be nil, the quic.Config may be nil. +// A single PacketConn only be used for a single call to Listen. +// The PacketConn can be used for simultaneous calls to Dial. +// QUIC connection IDs are used for demultiplexing the different connections. +// The tls.Config must not be nil and must contain a certificate configuration. +// The quic.Config may be nil, in that case the default values will be used. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { return listen(conn, tlsConf, config) } @@ -300,23 +304,17 @@ func (s *server) Addr() net.Addr { } func (s *server) handlePacket(p *receivedPacket) { - if err := s.handlePacketImpl(p); err != nil { - s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) - } -} - -func (s *server) handlePacketImpl(p *receivedPacket) error { - hdr := p.header + hdr := p.hdr // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - return s.sendVersionNegotiationPacket(p) + go s.sendVersionNegotiationPacket(p) + return } if hdr.Type == protocol.PacketTypeInitial { go s.handleInitial(p) } // TODO(#943): send Stateless Reset - return nil } func (s *server) handleInitial(p *receivedPacket) { @@ -335,11 +333,11 @@ func (s *server) handleInitial(p *receivedPacket) { } func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) { - hdr := p.header + hdr := p.hdr if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { return nil, nil, errors.New("dropping Initial packet with too short connection ID") } - if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize { + if len(p.data) < protocol.MinInitialPacketSize { return nil, nil, errors.New("dropping too small Initial packet") } @@ -358,7 +356,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con if !s.config.AcceptCookie(p.remoteAddr, cookie) { // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the session. - p.header.Log(s.logger) + (&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger) return nil, nil, s.sendRetry(p.remoteAddr, hdr) } @@ -431,19 +429,18 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { if err != nil { return err } - replyHdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - Version: hdr.Version, - SrcConnectionID: connID, - DestConnectionID: hdr.SrcConnectionID, - OrigDestConnectionID: hdr.DestConnectionID, - Token: token, - } + replyHdr := &wire.ExtendedHeader{} + replyHdr.IsLongHeader = true + replyHdr.Type = protocol.PacketTypeRetry + replyHdr.Version = hdr.Version + replyHdr.SrcConnectionID = connID + replyHdr.DestConnectionID = hdr.SrcConnectionID + replyHdr.OrigDestConnectionID = hdr.DestConnectionID + replyHdr.Token = token s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID) replyHdr.Log(s.logger) buf := &bytes.Buffer{} - if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil { + if err := replyHdr.Write(buf, hdr.Version); err != nil { return err } if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil { @@ -452,14 +449,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { return nil } -func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error { - hdr := p.header - s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) - +func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { + hdr := p.hdr + s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) if err != nil { - return err + s.logger.Debugf("Error composing Version Negotiation: %s", err) + return + } + if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil { + s.logger.Debugf("Error sending Version Negotiation: %s", err) } - _, err = s.conn.WriteTo(data, p.remoteAddr) - return err } diff --git a/vendor/github.com/lucas-clemente/quic-go/server_session.go b/vendor/github.com/lucas-clemente/quic-go/server_session.go index 0ba04680..d1ab73a4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/server_session.go +++ b/vendor/github.com/lucas-clemente/quic-go/server_session.go @@ -32,7 +32,7 @@ func (s *serverSession) handlePacket(p *receivedPacket) { } func (s *serverSession) handlePacketImpl(p *receivedPacket) error { - hdr := p.header + hdr := p.hdr // Probably an old packet that was sent by the client before the version was negotiated. // It is safe to drop it. diff --git a/vendor/github.com/lucas-clemente/quic-go/session.go b/vendor/github.com/lucas-clemente/quic-go/session.go index e319cfa0..6ab7746d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/session.go +++ b/vendor/github.com/lucas-clemente/quic-go/session.go @@ -1,6 +1,7 @@ package quic import ( + "bytes" "context" "crypto/tls" "errors" @@ -21,7 +22,7 @@ import ( ) type unpacker interface { - Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) + Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) } type streamGetter interface { @@ -52,7 +53,7 @@ type cryptoStreamHandler interface { type receivedPacket struct { remoteAddr net.Addr - header *wire.Header + hdr *wire.Header data []byte rcvTime time.Time } @@ -113,7 +114,6 @@ type session struct { receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this receivedFirstForwardSecurePacket bool - lastRcvdPacketNumber protocol.PacketNumber // Used to calculate the next packet number from the truncated wire // representation, and sent back in public reset packets largestRcvdPacketNumber protocol.PacketNumber @@ -289,7 +289,7 @@ var newClientSession = func( func (s *session) preSetup() { s.rttStats = &congestion.RTTStats{} - s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version) + s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.InitialMaxData, @@ -374,7 +374,7 @@ runLoop: } // This is a bit unclean, but works properly, since the packet always // begins with the public header and we never copy it. - putPacketBuffer(&p.header.Raw) + // TODO: putPacketBuffer(&p.extHdr.Raw) case <-s.handshakeCompleteChan: s.handleHandshakeComplete() } @@ -479,24 +479,41 @@ func (s *session) handleHandshakeComplete() { } func (s *session) handlePacketImpl(p *receivedPacket) error { - hdr := p.header // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { - s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID) + if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID) return nil } - p.rcvTime = time.Now() + data := p.data + r := bytes.NewReader(data) + hdr, err := p.hdr.ParseExtended(r, s.version) + if err != nil { + return fmt.Errorf("error parsing extended header: %s", err) + } + hdr.Raw = data[:len(data)-r.Len()] + data = data[len(data)-r.Len():] + + if hdr.IsLongHeader { + if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) { + return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen) + } + if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length { + return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length) + } + data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)] + // TODO(#1312): implement parsing of compound packets + } + // Calculate packet number hdr.PacketNumber = protocol.InferPacketNumber( hdr.PacketNumberLen, s.largestRcvdPacketNumber, hdr.PacketNumber, - s.version, ) - packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data) + packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) if s.logger.Debug() { if err != nil { s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID) @@ -530,7 +547,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } } - s.lastRcvdPacketNumber = hdr.PacketNumber // Only do this after decrypting, so we are sure the packet is not attacker-controlled s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) @@ -543,10 +559,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { } } - return s.handleFrames(packet.frames, packet.encryptionLevel) + return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel) } -func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error { +func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error { for _, ff := range fs { var err error wire.LogFrame(s.logger, ff, false) @@ -556,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve case *wire.StreamFrame: err = s.handleStreamFrame(frame, encLevel) case *wire.AckFrame: - err = s.handleAckFrame(frame, encLevel) + err = s.handleAckFrame(frame, pn, encLevel) case *wire.ConnectionCloseFrame: s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) case *wire.ResetStreamFrame: @@ -702,8 +718,8 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) } -func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { - if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil { +func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error { + if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil { return err } s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked()) @@ -1065,14 +1081,14 @@ func (s *session) scheduleSending() { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { if s.handshakeComplete { - s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data)) + s.logger.Debugf("Received undecryptable packet from %s after the handshake (%d bytes)", p.remoteAddr.String(), len(p.data)) return } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { - s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber) + s.logger.Infof("Dropping undecrytable packet (%d bytes). Undecryptable packet queue full.", len(p.data)) return } - s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber) + s.logger.Infof("Queueing packet (%d bytes) for later decryption", len(p.data)) s.undecryptablePackets = append(s.undecryptablePackets, p) }