mirror of https://github.com/v2ray/v2ray-core
update quic vendor
parent
90ab42b1cb
commit
135bf169c0
|
@ -27,7 +27,7 @@ type client struct {
|
|||
|
||||
token []byte
|
||||
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
versionNegotiated utils.AtomicBool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
|
@ -59,6 +59,7 @@ var (
|
|||
)
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC session is closed.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddr(
|
||||
addr string,
|
||||
|
@ -69,7 +70,7 @@ func DialAddr(
|
|||
}
|
||||
|
||||
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
// See DialAddr for details.
|
||||
func DialAddrContext(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
|
@ -88,6 +89,8 @@ func DialAddrContext(
|
|||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The same PacketConn can be used for multiple calls to Dial and Listen,
|
||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(
|
||||
pconn net.PacketConn,
|
||||
|
@ -100,7 +103,7 @@ func Dial(
|
|||
}
|
||||
|
||||
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
|
||||
// The host parameter is used for SNI.
|
||||
// See Dial for details.
|
||||
func DialContext(
|
||||
ctx context.Context,
|
||||
pconn net.PacketConn,
|
||||
|
@ -164,7 +167,18 @@ func newClient(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
srcConnID, err := generateConnectionID(config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destConnID, err := generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &client{
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
createdPacketConn: createdPacketConn,
|
||||
tlsConf: tlsConf,
|
||||
|
@ -173,7 +187,7 @@ func newClient(
|
|||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, c.generateConnectionIDs()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
|
@ -234,20 +248,6 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *client) generateConnectionIDs() error {
|
||||
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destConnID, err := generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.srcConnID = srcConnID
|
||||
c.destConnID = destConnID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
|
@ -292,65 +292,49 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (c *client) handlePacket(p *receivedPacket) {
|
||||
if err := c.handlePacketImpl(p); err != nil {
|
||||
c.logger.Errorf("error handling packet: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handlePacketImpl(p *receivedPacket) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// handle Version Negotiation Packets
|
||||
if p.header.IsVersionNegotiation {
|
||||
err := c.handleVersionNegotiationPacket(p.header)
|
||||
if err != nil {
|
||||
c.session.destroy(err)
|
||||
}
|
||||
// version negotiation packets have no payload
|
||||
return err
|
||||
if p.hdr.IsVersionNegotiation() {
|
||||
go c.handleVersionNegotiationPacket(p.hdr)
|
||||
return
|
||||
}
|
||||
|
||||
// reject packets with the wrong connection ID
|
||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
|
||||
if p.header.Type == protocol.PacketTypeRetry {
|
||||
c.handleRetryPacket(p.header)
|
||||
return nil
|
||||
if p.hdr.Type == protocol.PacketTypeRetry {
|
||||
go c.handleRetryPacket(p.hdr)
|
||||
return
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||
if !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
if !c.versionNegotiated.Get() {
|
||||
c.versionNegotiated.Set(true)
|
||||
}
|
||||
|
||||
c.session.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
|
||||
return nil
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
|
||||
c.logger.Debugf("Received a delayed Version Negotiation packet.")
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range hdr.SupportedVersions {
|
||||
if v == c.version {
|
||||
// the version negotiation packet contains the version that we offered
|
||||
// this might be a packet sent by an attacker (or by a terribly broken server implementation)
|
||||
// ignore it
|
||||
return nil
|
||||
// The Version Negotiation packet contains the version that we offered.
|
||||
// This might be a packet sent by an attacker (or by a terribly broken server implementation).
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||
c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
return qerr.InvalidVersion
|
||||
c.session.destroy(qerr.InvalidVersion)
|
||||
c.logger.Debugf("No compatible version found.")
|
||||
return
|
||||
}
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
c.negotiatedVersions = hdr.SupportedVersions
|
||||
|
@ -358,18 +342,17 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
// switch to negotiated version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
if err := c.generateConnectionIDs(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||
c.session.destroy(errCloseSessionForNewVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.logger.Debugf("<- Received Retry")
|
||||
hdr.Log(c.logger)
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
|
||||
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
|
||||
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
||||
return
|
||||
|
|
7
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
7
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
|
@ -75,12 +75,10 @@ type sentPacketHandler struct {
|
|||
alarm time.Time
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
// NewSentPacketHandler creates a new sentPacketHandler
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
|
@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
|
|||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
|
|||
|
||||
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||
pn := h.packetNumberGenerator.Peek()
|
||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
|
||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked())
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||
|
|
|
@ -11,11 +11,13 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
|||
c = &tls.Config{}
|
||||
}
|
||||
// QUIC requires TLS 1.3 or newer
|
||||
if c.MinVersion < qtls.VersionTLS13 {
|
||||
c.MinVersion = qtls.VersionTLS13
|
||||
minVersion := c.MinVersion
|
||||
if minVersion < qtls.VersionTLS13 {
|
||||
minVersion = qtls.VersionTLS13
|
||||
}
|
||||
if c.MaxVersion < qtls.VersionTLS13 {
|
||||
c.MaxVersion = qtls.VersionTLS13
|
||||
maxVersion := c.MaxVersion
|
||||
if maxVersion < qtls.VersionTLS13 {
|
||||
maxVersion = qtls.VersionTLS13
|
||||
}
|
||||
return &qtls.Config{
|
||||
Rand: c.Rand,
|
||||
|
@ -38,8 +40,8 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
|||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||
SessionTicketKey: c.SessionTicketKey,
|
||||
MinVersion: c.MinVersion,
|
||||
MaxVersion: c.MaxVersion,
|
||||
MinVersion: minVersion,
|
||||
MaxVersion: maxVersion,
|
||||
CurvePreferences: c.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||
Renegotiation: c.Renegotiation,
|
||||
|
|
|
@ -1,20 +1,37 @@
|
|||
package protocol
|
||||
|
||||
// PacketNumberLen is the length of the packet number in bytes
|
||||
type PacketNumberLen uint8
|
||||
|
||||
const (
|
||||
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
|
||||
PacketNumberLenInvalid PacketNumberLen = 0
|
||||
// PacketNumberLen1 is a packet number length of 1 byte
|
||||
PacketNumberLen1 PacketNumberLen = 1
|
||||
// PacketNumberLen2 is a packet number length of 2 bytes
|
||||
PacketNumberLen2 PacketNumberLen = 2
|
||||
// PacketNumberLen3 is a packet number length of 3 bytes
|
||||
PacketNumberLen3 PacketNumberLen = 3
|
||||
// PacketNumberLen4 is a packet number length of 4 bytes
|
||||
PacketNumberLen4 PacketNumberLen = 4
|
||||
)
|
||||
|
||||
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
|
||||
func InferPacketNumber(
|
||||
packetNumberLength PacketNumberLen,
|
||||
lastPacketNumber PacketNumber,
|
||||
wirePacketNumber PacketNumber,
|
||||
version VersionNumber,
|
||||
) PacketNumber {
|
||||
var epochDelta PacketNumber
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 7
|
||||
epochDelta = PacketNumber(1) << 8
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 14
|
||||
epochDelta = PacketNumber(1) << 16
|
||||
case PacketNumberLen3:
|
||||
epochDelta = PacketNumber(1) << 24
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 30
|
||||
epochDelta = PacketNumber(1) << 32
|
||||
}
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
prevEpochBegin := epoch - epochDelta
|
||||
|
@ -42,11 +59,14 @@ func delta(a, b PacketNumber) PacketNumber {
|
|||
|
||||
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
|
||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
|
||||
diff := uint64(packetNumber - leastUnacked)
|
||||
if diff < (1 << (14 - 1)) {
|
||||
if diff < (1 << (16 - 1)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
if diff < (1 << (24 - 1)) {
|
||||
return PacketNumberLen3
|
||||
}
|
||||
return PacketNumberLen4
|
||||
}
|
||||
|
||||
|
@ -58,5 +78,8 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
|
|||
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
if packetNumber < (1 << (uint8(PacketNumberLen3) * 8)) {
|
||||
return PacketNumberLen3
|
||||
}
|
||||
return PacketNumberLen4
|
||||
}
|
||||
|
|
|
@ -7,32 +7,18 @@ import (
|
|||
// A PacketNumber in QUIC
|
||||
type PacketNumber uint64
|
||||
|
||||
// PacketNumberLen is the length of the packet number in bytes
|
||||
type PacketNumberLen uint8
|
||||
|
||||
const (
|
||||
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
|
||||
PacketNumberLenInvalid PacketNumberLen = 0
|
||||
// PacketNumberLen1 is a packet number length of 1 byte
|
||||
PacketNumberLen1 PacketNumberLen = 1
|
||||
// PacketNumberLen2 is a packet number length of 2 bytes
|
||||
PacketNumberLen2 PacketNumberLen = 2
|
||||
// PacketNumberLen4 is a packet number length of 4 bytes
|
||||
PacketNumberLen4 PacketNumberLen = 4
|
||||
)
|
||||
|
||||
// The PacketType is the Long Header Type
|
||||
type PacketType uint8
|
||||
|
||||
const (
|
||||
// PacketTypeInitial is the packet type of an Initial packet
|
||||
PacketTypeInitial PacketType = 0x7f
|
||||
PacketTypeInitial PacketType = 1 + iota
|
||||
// PacketTypeRetry is the packet type of a Retry packet
|
||||
PacketTypeRetry PacketType = 0x7e
|
||||
PacketTypeRetry
|
||||
// PacketTypeHandshake is the packet type of a Handshake packet
|
||||
PacketTypeHandshake PacketType = 0x7d
|
||||
PacketTypeHandshake
|
||||
// PacketType0RTT is the packet type of a 0-RTT packet
|
||||
PacketType0RTT PacketType = 0x7c
|
||||
PacketType0RTT
|
||||
)
|
||||
|
||||
func (t PacketType) String() string {
|
||||
|
@ -72,11 +58,5 @@ const DefaultTCPMSS ByteCount = 1460
|
|||
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
|
||||
const MinInitialPacketSize = 1200
|
||||
|
||||
// MaxClientHellos is the maximum number of times we'll send a client hello
|
||||
// The value 3 accounts for:
|
||||
// * one failure due to an incorrect or missing source-address token
|
||||
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
|
||||
const MaxClientHellos = 3
|
||||
|
||||
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
||||
const MinConnectionIDLenInitial = 8
|
||||
|
|
|
@ -8,11 +8,10 @@ import (
|
|||
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
|
||||
type ByteOrder interface {
|
||||
ReadUintN(b io.ByteReader, length uint8) (uint64, error)
|
||||
ReadUint64(io.ByteReader) (uint64, error)
|
||||
ReadUint32(io.ByteReader) (uint32, error)
|
||||
ReadUint16(io.ByteReader) (uint16, error)
|
||||
|
||||
WriteUint64(*bytes.Buffer, uint64)
|
||||
WriteUintN(b *bytes.Buffer, length uint8, value uint64)
|
||||
WriteUint32(*bytes.Buffer, uint32)
|
||||
WriteUint16(*bytes.Buffer, uint16)
|
||||
}
|
||||
|
|
41
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
41
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
|
@ -25,37 +25,6 @@ func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
|
|||
return res, nil
|
||||
}
|
||||
|
||||
// ReadUint64 reads a uint64
|
||||
func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) {
|
||||
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
|
||||
var err error
|
||||
if b8, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b7, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b6, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b5, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b4, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
|
||||
}
|
||||
|
||||
// ReadUint32 reads a uint32
|
||||
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
|
||||
var b1, b2, b3, b4 uint8
|
||||
|
@ -88,12 +57,10 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
|
|||
return uint16(b1) + uint16(b2)<<8, nil
|
||||
}
|
||||
|
||||
// WriteUint64 writes a uint64
|
||||
func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
|
||||
b.Write([]byte{
|
||||
uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
|
||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
||||
})
|
||||
func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) {
|
||||
for j := length; j > 0; j-- {
|
||||
b.WriteByte(uint8(i >> (8 * (j - 1))))
|
||||
}
|
||||
}
|
||||
|
||||
// WriteUint32 writes a uint32
|
||||
|
|
205
vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go
generated
vendored
Normal file
205
vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go
generated
vendored
Normal file
|
@ -0,0 +1,205 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// ExtendedHeader is the header of a QUIC packet.
|
||||
type ExtendedHeader struct {
|
||||
Header
|
||||
|
||||
typeByte byte
|
||||
Raw []byte
|
||||
|
||||
PacketNumberLen protocol.PacketNumberLen
|
||||
PacketNumber protocol.PacketNumber
|
||||
|
||||
KeyPhase int
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
// read the (now unencrypted) first byte
|
||||
var err error
|
||||
h.typeByte, err = b.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if h.IsLongHeader {
|
||||
return h.parseLongHeader(b, v)
|
||||
}
|
||||
return h.parseShortHeader(b, v)
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
if h.typeByte&0xc != 0 {
|
||||
return nil, errors.New("5th and 6th bit must be 0")
|
||||
}
|
||||
if err := h.readPacketNumber(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
if h.typeByte&0x18 != 0 {
|
||||
return nil, errors.New("4th and 5th bit must be 0")
|
||||
}
|
||||
|
||||
h.KeyPhase = int(h.typeByte&0x4) >> 2
|
||||
|
||||
if err := h.readPacketNumber(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
|
||||
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
|
||||
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.PacketNumber = protocol.PacketNumber(pn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes the Header.
|
||||
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
|
||||
if h.IsLongHeader {
|
||||
return h.writeLongHeader(b, ver)
|
||||
}
|
||||
return h.writeShortHeader(b, ver)
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||
var packetType uint8
|
||||
switch h.Type {
|
||||
case protocol.PacketTypeInitial:
|
||||
packetType = 0x0
|
||||
case protocol.PacketType0RTT:
|
||||
packetType = 0x1
|
||||
case protocol.PacketTypeHandshake:
|
||||
packetType = 0x2
|
||||
case protocol.PacketTypeRetry:
|
||||
packetType = 0x3
|
||||
}
|
||||
firstByte := 0xc0 | packetType<<4
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
firstByte |= odcil
|
||||
} else { // Retry packets don't have a packet number
|
||||
firstByte |= uint8(h.PacketNumberLen - 1)
|
||||
}
|
||||
|
||||
b.WriteByte(firstByte)
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
||||
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.WriteByte(connIDLen)
|
||||
b.Write(h.DestConnectionID.Bytes())
|
||||
b.Write(h.SrcConnectionID.Bytes())
|
||||
|
||||
switch h.Type {
|
||||
case protocol.PacketTypeRetry:
|
||||
b.Write(h.OrigDestConnectionID.Bytes())
|
||||
b.Write(h.Token)
|
||||
return nil
|
||||
case protocol.PacketTypeInitial:
|
||||
utils.WriteVarInt(b, uint64(len(h.Token)))
|
||||
b.Write(h.Token)
|
||||
}
|
||||
|
||||
utils.WriteVarInt(b, uint64(h.Length))
|
||||
return h.writePacketNumber(b)
|
||||
}
|
||||
|
||||
// TODO: add support for the key phase
|
||||
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||
typeByte := 0x40 | uint8(h.PacketNumberLen-1)
|
||||
typeByte |= byte(h.KeyPhase << 2)
|
||||
|
||||
b.WriteByte(typeByte)
|
||||
b.Write(h.DestConnectionID.Bytes())
|
||||
return h.writePacketNumber(b)
|
||||
}
|
||||
|
||||
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
|
||||
if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 {
|
||||
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
|
||||
}
|
||||
utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLength determines the length of the Header.
|
||||
func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
|
||||
if h.IsLongHeader {
|
||||
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
||||
}
|
||||
return length
|
||||
}
|
||||
|
||||
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
||||
length += protocol.ByteCount(h.PacketNumberLen)
|
||||
return length
|
||||
}
|
||||
|
||||
// Log logs the Header
|
||||
func (h *ExtendedHeader) Log(logger utils.Logger) {
|
||||
if h.IsLongHeader {
|
||||
var token string
|
||||
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
|
||||
if len(h.Token) == 0 {
|
||||
token = "Token: (empty), "
|
||||
} else {
|
||||
token = fmt.Sprintf("Token: %#x, ", h.Token)
|
||||
}
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
|
||||
return
|
||||
}
|
||||
}
|
||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
|
||||
} else {
|
||||
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
|
||||
}
|
||||
}
|
||||
|
||||
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
||||
dcil, err := encodeSingleConnIDLen(dest)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
scil, err := encodeSingleConnIDLen(src)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return scil | dcil<<4, nil
|
||||
}
|
||||
|
||||
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
|
||||
len := id.Len()
|
||||
if len == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len < 4 || len > 18 {
|
||||
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
|
||||
}
|
||||
return byte(len - 3), nil
|
||||
}
|
|
@ -2,150 +2,183 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Header is the header of a QUIC packet.
|
||||
// The Header is the version independent part of the header
|
||||
type Header struct {
|
||||
Raw []byte
|
||||
Version protocol.VersionNumber
|
||||
SrcConnectionID protocol.ConnectionID
|
||||
DestConnectionID protocol.ConnectionID
|
||||
|
||||
Version protocol.VersionNumber
|
||||
|
||||
DestConnectionID protocol.ConnectionID
|
||||
SrcConnectionID protocol.ConnectionID
|
||||
OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
|
||||
|
||||
PacketNumberLen protocol.PacketNumberLen
|
||||
PacketNumber protocol.PacketNumber
|
||||
|
||||
IsVersionNegotiation bool
|
||||
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
|
||||
|
||||
Type protocol.PacketType
|
||||
IsLongHeader bool
|
||||
KeyPhase int
|
||||
Type protocol.PacketType
|
||||
Length protocol.ByteCount
|
||||
Token []byte
|
||||
|
||||
Token []byte
|
||||
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
|
||||
OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
|
||||
|
||||
typeByte byte
|
||||
len int // how many bytes were read while parsing this header
|
||||
}
|
||||
|
||||
// Write writes the Header.
|
||||
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
|
||||
if h.IsLongHeader {
|
||||
return h.writeLongHeader(b, ver)
|
||||
// ParseHeader parses the header.
|
||||
// For short header packets: up to the packet number.
|
||||
// For long header packets:
|
||||
// * if we understand the version: up to the packet number
|
||||
// * if not, only the invariant part of the header
|
||||
func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
|
||||
startLen := b.Len()
|
||||
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.writeShortHeader(b, ver)
|
||||
h.len = startLen - b.Len()
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// TODO: add support for the key phase
|
||||
func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||
b.WriteByte(byte(0x80 | h.Type))
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
||||
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
|
||||
func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
|
||||
typeByte, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h := &Header{
|
||||
typeByte: typeByte,
|
||||
IsLongHeader: typeByte&0x80 > 0,
|
||||
}
|
||||
|
||||
if !h.IsLongHeader {
|
||||
if h.typeByte&0x40 == 0 {
|
||||
return nil, errors.New("not a QUIC packet")
|
||||
}
|
||||
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
if err := h.parseLongHeader(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
|
||||
var err error
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.WriteByte(connIDLen)
|
||||
b.Write(h.DestConnectionID.Bytes())
|
||||
b.Write(h.SrcConnectionID.Bytes())
|
||||
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
utils.WriteVarInt(b, uint64(len(h.Token)))
|
||||
b.Write(h.Token)
|
||||
h.Version = protocol.VersionNumber(v)
|
||||
if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 {
|
||||
return errors.New("not a QUIC packet")
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// randomize the first 4 bits
|
||||
odcilByte := make([]byte, 1)
|
||||
_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
|
||||
odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
|
||||
b.Write(odcilByte)
|
||||
b.Write(h.OrigDestConnectionID.Bytes())
|
||||
b.Write(h.Token)
|
||||
connIDLenByte, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dcil, scil := decodeConnIDLen(connIDLenByte)
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if h.Version == 0 {
|
||||
return h.parseVersionNegotiationPacket(b)
|
||||
}
|
||||
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
|
||||
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
|
||||
return nil
|
||||
}
|
||||
|
||||
utils.WriteVarInt(b, uint64(h.Length))
|
||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
}
|
||||
switch (h.typeByte & 0x30) >> 4 {
|
||||
case 0x0:
|
||||
h.Type = protocol.PacketTypeInitial
|
||||
case 0x1:
|
||||
h.Type = protocol.PacketType0RTT
|
||||
case 0x2:
|
||||
h.Type = protocol.PacketTypeHandshake
|
||||
case 0x3:
|
||||
h.Type = protocol.PacketTypeRetry
|
||||
}
|
||||
|
||||
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||
typeByte := byte(0x30)
|
||||
typeByte |= byte(h.KeyPhase << 6)
|
||||
|
||||
b.WriteByte(typeByte)
|
||||
b.Write(h.DestConnectionID.Bytes())
|
||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
}
|
||||
|
||||
// GetLength determines the length of the Header.
|
||||
func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
|
||||
if h.IsLongHeader {
|
||||
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
odcil := decodeSingleConnIDLen(h.typeByte & 0xf)
|
||||
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return length
|
||||
}
|
||||
|
||||
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
||||
length += protocol.ByteCount(h.PacketNumberLen)
|
||||
return length
|
||||
}
|
||||
|
||||
// Log logs the Header
|
||||
func (h *Header) Log(logger utils.Logger) {
|
||||
if h.IsLongHeader {
|
||||
if h.Version == 0 {
|
||||
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
|
||||
} else {
|
||||
var token string
|
||||
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
|
||||
if len(h.Token) == 0 {
|
||||
token = "Token: (empty), "
|
||||
} else {
|
||||
token = fmt.Sprintf("Token: %#x, ", h.Token)
|
||||
}
|
||||
}
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
|
||||
return
|
||||
}
|
||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
|
||||
h.Token = make([]byte, b.Len())
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
|
||||
return nil
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
tokenLen, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tokenLen > uint64(b.Len()) {
|
||||
return io.EOF
|
||||
}
|
||||
h.Token = make([]byte, tokenLen)
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
pl, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Length = protocol.ByteCount(pl)
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
||||
dcil, err := encodeSingleConnIDLen(dest)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error {
|
||||
if b.Len() == 0 {
|
||||
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
||||
}
|
||||
scil, err := encodeSingleConnIDLen(src)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
|
||||
for i := 0; b.Len() > 0; i++ {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return qerr.InvalidVersionNegotiationPacket
|
||||
}
|
||||
h.SupportedVersions[i] = protocol.VersionNumber(v)
|
||||
}
|
||||
return scil | dcil<<4, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
|
||||
len := id.Len()
|
||||
if len == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len < 4 || len > 18 {
|
||||
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
|
||||
}
|
||||
return byte(len - 3), nil
|
||||
// IsVersionNegotiation says if this a version negotiation packet
|
||||
func (h *Header) IsVersionNegotiation() bool {
|
||||
return h.IsLongHeader && h.Version == 0
|
||||
}
|
||||
|
||||
// ParseExtended parses the version dependent part of the header.
|
||||
// The Reader has to be set such that it points to the first byte of the header.
|
||||
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||
return h.toExtendedHeader().parse(b, ver)
|
||||
}
|
||||
|
||||
func (h *Header) toExtendedHeader() *ExtendedHeader {
|
||||
return &ExtendedHeader{Header: *h}
|
||||
}
|
||||
|
||||
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
|
||||
|
|
|
@ -162,75 +162,46 @@ func (h *packetHandlerMap) listen() {
|
|||
}
|
||||
|
||||
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
||||
rcvTime := time.Now()
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
|
||||
hdr, err := wire.ParseHeader(r, h.connIDLen)
|
||||
// drop the packet if we can't parse the header
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing invariant header: %s", err)
|
||||
}
|
||||
|
||||
h.mutex.RLock()
|
||||
handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)]
|
||||
server := h.server
|
||||
|
||||
var sentBy protocol.Perspective
|
||||
var version protocol.VersionNumber
|
||||
var handlePacket func(*receivedPacket)
|
||||
if handlerFound { // existing session
|
||||
handler := handlerEntry.handler
|
||||
sentBy = handler.GetPerspective().Opposite()
|
||||
version = handler.GetVersion()
|
||||
handlePacket = handler.handlePacket
|
||||
} else { // no session found
|
||||
// this might be a stateless reset
|
||||
if !iHdr.IsLongHeader {
|
||||
if len(data) >= protocol.MinStatelessResetSize {
|
||||
var token [16]byte
|
||||
copy(token[:], data[len(data)-16:])
|
||||
if sess, ok := h.resetTokens[token]; ok {
|
||||
h.mutex.RUnlock()
|
||||
sess.destroy(errors.New("received a stateless reset"))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// TODO(#943): send a stateless reset
|
||||
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||
}
|
||||
if server == nil { // no server set
|
||||
h.mutex.RUnlock()
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||
}
|
||||
handlePacket = server.handlePacket
|
||||
sentBy = protocol.PerspectiveClient
|
||||
version = iHdr.Version
|
||||
}
|
||||
h.mutex.RUnlock()
|
||||
|
||||
hdr, err := iHdr.Parse(r, sentBy, version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing header: %s", err)
|
||||
}
|
||||
hdr.Raw = data[:len(data)-r.Len()]
|
||||
packetData := data[len(data)-r.Len():]
|
||||
|
||||
if hdr.IsLongHeader {
|
||||
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
|
||||
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
|
||||
}
|
||||
if protocol.ByteCount(len(packetData))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
|
||||
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(packetData)+int(hdr.PacketNumberLen), hdr.Length)
|
||||
}
|
||||
packetData = packetData[:int(hdr.Length)-int(hdr.PacketNumberLen)]
|
||||
// TODO(#1312): implement parsing of compound packets
|
||||
p := &receivedPacket{
|
||||
remoteAddr: addr,
|
||||
hdr: hdr,
|
||||
data: data,
|
||||
rcvTime: time.Now(),
|
||||
}
|
||||
|
||||
handlePacket(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: packetData,
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
|
||||
|
||||
if handlerFound { // existing session
|
||||
handlerEntry.handler.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
// No session found.
|
||||
// This might be a stateless reset.
|
||||
if !hdr.IsLongHeader {
|
||||
if len(data) >= protocol.MinStatelessResetSize {
|
||||
var token [16]byte
|
||||
copy(token[:], data[len(data)-16:])
|
||||
if sess, ok := h.resetTokens[token]; ok {
|
||||
sess.destroy(errors.New("received a stateless reset"))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// TODO(#943): send a stateless reset
|
||||
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
|
||||
}
|
||||
if h.server == nil { // no server set
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
|
||||
}
|
||||
h.server.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ type packer interface {
|
|||
}
|
||||
|
||||
type packedPacket struct {
|
||||
header *wire.Header
|
||||
header *wire.ExtendedHeader
|
||||
raw []byte
|
||||
frames []wire.Frame
|
||||
encryptionLevel protocol.EncryptionLevel
|
||||
|
@ -397,14 +397,13 @@ func (p *packetPacker) composeNextPacket(
|
|||
return frames, nil
|
||||
}
|
||||
|
||||
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
|
||||
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
|
||||
pn, pnLen := p.pnManager.PeekPacketNumber()
|
||||
header := &wire.Header{
|
||||
PacketNumber: pn,
|
||||
PacketNumberLen: pnLen,
|
||||
Version: p.version,
|
||||
DestConnectionID: p.destConnID,
|
||||
}
|
||||
header := &wire.ExtendedHeader{}
|
||||
header.PacketNumber = pn
|
||||
header.PacketNumberLen = pnLen
|
||||
header.Version = p.version
|
||||
header.DestConnectionID = p.destConnID
|
||||
|
||||
if encLevel != protocol.Encryption1RTT {
|
||||
header.IsLongHeader = true
|
||||
|
@ -424,8 +423,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
|
|||
}
|
||||
|
||||
func (p *packetPacker) writeAndSealPacket(
|
||||
header *wire.Header,
|
||||
frames []wire.Frame,
|
||||
header *wire.ExtendedHeader, frames []wire.Frame,
|
||||
sealer handshake.Sealer,
|
||||
) ([]byte, error) {
|
||||
raw := *getPacketBuffer()
|
||||
|
@ -450,7 +448,7 @@ func (p *packetPacker) writeAndSealPacket(
|
|||
}
|
||||
}
|
||||
|
||||
if err := header.Write(buffer, p.perspective, p.version); err != nil {
|
||||
if err := header.Write(buffer, p.version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadStartIndex := buffer.Len()
|
||||
|
|
|
@ -35,7 +35,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
|
|||
}
|
||||
}
|
||||
|
||||
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
||||
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
|
||||
buf := *getPacketBuffer()
|
||||
buf = buf[:0]
|
||||
defer putPacketBuffer(&buf)
|
||||
|
|
|
@ -21,7 +21,6 @@ type packetHandler interface {
|
|||
handlePacket(*receivedPacket)
|
||||
io.Closer
|
||||
destroy(error)
|
||||
GetVersion() protocol.VersionNumber
|
||||
GetPerspective() protocol.Perspective
|
||||
}
|
||||
|
||||
|
@ -99,7 +98,8 @@ var _ Listener = &server{}
|
|||
var _ unknownPacketHandler = &server{}
|
||||
|
||||
// ListenAddr creates a QUIC server listening on a given address.
|
||||
// The tls.Config must not be nil, the quic.Config may be nil.
|
||||
// The tls.Config must not be nil and must contain a certificate configuration.
|
||||
// The quic.Config may be nil, in that case the default values will be used.
|
||||
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
|
@ -118,7 +118,11 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
|
|||
}
|
||||
|
||||
// Listen listens for QUIC connections on a given net.PacketConn.
|
||||
// The tls.Config must not be nil, the quic.Config may be nil.
|
||||
// A single PacketConn only be used for a single call to Listen.
|
||||
// The PacketConn can be used for simultaneous calls to Dial.
|
||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||
// The tls.Config must not be nil and must contain a certificate configuration.
|
||||
// The quic.Config may be nil, in that case the default values will be used.
|
||||
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
|
||||
return listen(conn, tlsConf, config)
|
||||
}
|
||||
|
@ -300,23 +304,17 @@ func (s *server) Addr() net.Addr {
|
|||
}
|
||||
|
||||
func (s *server) handlePacket(p *receivedPacket) {
|
||||
if err := s.handlePacketImpl(p); err != nil {
|
||||
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) handlePacketImpl(p *receivedPacket) error {
|
||||
hdr := p.header
|
||||
hdr := p.hdr
|
||||
|
||||
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
||||
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
||||
return s.sendVersionNegotiationPacket(p)
|
||||
go s.sendVersionNegotiationPacket(p)
|
||||
return
|
||||
}
|
||||
if hdr.Type == protocol.PacketTypeInitial {
|
||||
go s.handleInitial(p)
|
||||
}
|
||||
// TODO(#943): send Stateless Reset
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) handleInitial(p *receivedPacket) {
|
||||
|
@ -335,11 +333,11 @@ func (s *server) handleInitial(p *receivedPacket) {
|
|||
}
|
||||
|
||||
func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
|
||||
hdr := p.header
|
||||
hdr := p.hdr
|
||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
|
||||
}
|
||||
if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize {
|
||||
if len(p.data) < protocol.MinInitialPacketSize {
|
||||
return nil, nil, errors.New("dropping too small Initial packet")
|
||||
}
|
||||
|
||||
|
@ -358,7 +356,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con
|
|||
if !s.config.AcceptCookie(p.remoteAddr, cookie) {
|
||||
// Log the Initial packet now.
|
||||
// If no Retry is sent, the packet will be logged by the session.
|
||||
p.header.Log(s.logger)
|
||||
(&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger)
|
||||
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
|
||||
}
|
||||
|
||||
|
@ -431,19 +429,18 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
replyHdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
Version: hdr.Version,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: hdr.SrcConnectionID,
|
||||
OrigDestConnectionID: hdr.DestConnectionID,
|
||||
Token: token,
|
||||
}
|
||||
replyHdr := &wire.ExtendedHeader{}
|
||||
replyHdr.IsLongHeader = true
|
||||
replyHdr.Type = protocol.PacketTypeRetry
|
||||
replyHdr.Version = hdr.Version
|
||||
replyHdr.SrcConnectionID = connID
|
||||
replyHdr.DestConnectionID = hdr.SrcConnectionID
|
||||
replyHdr.OrigDestConnectionID = hdr.DestConnectionID
|
||||
replyHdr.Token = token
|
||||
s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
|
||||
replyHdr.Log(s.logger)
|
||||
buf := &bytes.Buffer{}
|
||||
if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil {
|
||||
if err := replyHdr.Write(buf, hdr.Version); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
|
||||
|
@ -452,14 +449,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
|
||||
hdr := p.header
|
||||
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
|
||||
|
||||
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
|
||||
hdr := p.hdr
|
||||
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
|
||||
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
||||
if err != nil {
|
||||
return err
|
||||
s.logger.Debugf("Error composing Version Negotiation: %s", err)
|
||||
return
|
||||
}
|
||||
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
|
||||
s.logger.Debugf("Error sending Version Negotiation: %s", err)
|
||||
}
|
||||
_, err = s.conn.WriteTo(data, p.remoteAddr)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ func (s *serverSession) handlePacket(p *receivedPacket) {
|
|||
}
|
||||
|
||||
func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
|
||||
hdr := p.header
|
||||
hdr := p.hdr
|
||||
|
||||
// Probably an old packet that was sent by the client before the version was negotiated.
|
||||
// It is safe to drop it.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
|
@ -21,7 +22,7 @@ import (
|
|||
)
|
||||
|
||||
type unpacker interface {
|
||||
Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error)
|
||||
Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
|
||||
}
|
||||
|
||||
type streamGetter interface {
|
||||
|
@ -52,7 +53,7 @@ type cryptoStreamHandler interface {
|
|||
|
||||
type receivedPacket struct {
|
||||
remoteAddr net.Addr
|
||||
header *wire.Header
|
||||
hdr *wire.Header
|
||||
data []byte
|
||||
rcvTime time.Time
|
||||
}
|
||||
|
@ -113,7 +114,6 @@ type session struct {
|
|||
|
||||
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
|
||||
receivedFirstForwardSecurePacket bool
|
||||
lastRcvdPacketNumber protocol.PacketNumber
|
||||
// Used to calculate the next packet number from the truncated wire
|
||||
// representation, and sent back in public reset packets
|
||||
largestRcvdPacketNumber protocol.PacketNumber
|
||||
|
@ -289,7 +289,7 @@ var newClientSession = func(
|
|||
|
||||
func (s *session) preSetup() {
|
||||
s.rttStats = &congestion.RTTStats{}
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version)
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
|
||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.InitialMaxData,
|
||||
|
@ -374,7 +374,7 @@ runLoop:
|
|||
}
|
||||
// This is a bit unclean, but works properly, since the packet always
|
||||
// begins with the public header and we never copy it.
|
||||
putPacketBuffer(&p.header.Raw)
|
||||
// TODO: putPacketBuffer(&p.extHdr.Raw)
|
||||
case <-s.handshakeCompleteChan:
|
||||
s.handleHandshakeComplete()
|
||||
}
|
||||
|
@ -479,24 +479,41 @@ func (s *session) handleHandshakeComplete() {
|
|||
}
|
||||
|
||||
func (s *session) handlePacketImpl(p *receivedPacket) error {
|
||||
hdr := p.header
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
// After this, all packets with a different source connection have to be ignored.
|
||||
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID)
|
||||
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID)
|
||||
return nil
|
||||
}
|
||||
|
||||
p.rcvTime = time.Now()
|
||||
data := p.data
|
||||
r := bytes.NewReader(data)
|
||||
hdr, err := p.hdr.ParseExtended(r, s.version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing extended header: %s", err)
|
||||
}
|
||||
hdr.Raw = data[:len(data)-r.Len()]
|
||||
data = data[len(data)-r.Len():]
|
||||
|
||||
if hdr.IsLongHeader {
|
||||
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
|
||||
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
|
||||
}
|
||||
if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
|
||||
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
|
||||
}
|
||||
data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
|
||||
// TODO(#1312): implement parsing of compound packets
|
||||
}
|
||||
|
||||
// Calculate packet number
|
||||
hdr.PacketNumber = protocol.InferPacketNumber(
|
||||
hdr.PacketNumberLen,
|
||||
s.largestRcvdPacketNumber,
|
||||
hdr.PacketNumber,
|
||||
s.version,
|
||||
)
|
||||
|
||||
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data)
|
||||
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
|
||||
if s.logger.Debug() {
|
||||
if err != nil {
|
||||
s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID)
|
||||
|
@ -530,7 +547,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
|
|||
}
|
||||
}
|
||||
|
||||
s.lastRcvdPacketNumber = hdr.PacketNumber
|
||||
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
|
||||
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
|
||||
|
||||
|
@ -543,10 +559,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
|
|||
}
|
||||
}
|
||||
|
||||
return s.handleFrames(packet.frames, packet.encryptionLevel)
|
||||
return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel)
|
||||
}
|
||||
|
||||
func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error {
|
||||
func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
|
||||
for _, ff := range fs {
|
||||
var err error
|
||||
wire.LogFrame(s.logger, ff, false)
|
||||
|
@ -556,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
|
|||
case *wire.StreamFrame:
|
||||
err = s.handleStreamFrame(frame, encLevel)
|
||||
case *wire.AckFrame:
|
||||
err = s.handleAckFrame(frame, encLevel)
|
||||
err = s.handleAckFrame(frame, pn, encLevel)
|
||||
case *wire.ConnectionCloseFrame:
|
||||
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
|
||||
case *wire.ResetStreamFrame:
|
||||
|
@ -702,8 +718,8 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) {
|
|||
s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data})
|
||||
}
|
||||
|
||||
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
|
||||
if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
|
||||
func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
|
||||
if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil {
|
||||
return err
|
||||
}
|
||||
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
|
||||
|
@ -1065,14 +1081,14 @@ func (s *session) scheduleSending() {
|
|||
|
||||
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
|
||||
if s.handshakeComplete {
|
||||
s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data))
|
||||
s.logger.Debugf("Received undecryptable packet from %s after the handshake (%d bytes)", p.remoteAddr.String(), len(p.data))
|
||||
return
|
||||
}
|
||||
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
|
||||
s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber)
|
||||
s.logger.Infof("Dropping undecrytable packet (%d bytes). Undecryptable packet queue full.", len(p.data))
|
||||
return
|
||||
}
|
||||
s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber)
|
||||
s.logger.Infof("Queueing packet (%d bytes) for later decryption", len(p.data))
|
||||
s.undecryptablePackets = append(s.undecryptablePackets, p)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue