mirror of https://github.com/v2ray/v2ray-core
update quic vendor
parent
90ab42b1cb
commit
135bf169c0
|
@ -27,7 +27,7 @@ type client struct {
|
||||||
|
|
||||||
token []byte
|
token []byte
|
||||||
|
|
||||||
versionNegotiated bool // has the server accepted our version
|
versionNegotiated utils.AtomicBool // has the server accepted our version
|
||||||
receivedVersionNegotiationPacket bool
|
receivedVersionNegotiationPacket bool
|
||||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// DialAddr establishes a new QUIC connection to a server.
|
// DialAddr establishes a new QUIC connection to a server.
|
||||||
|
// It uses a new UDP connection and closes this connection when the QUIC session is closed.
|
||||||
// The hostname for SNI is taken from the given address.
|
// The hostname for SNI is taken from the given address.
|
||||||
func DialAddr(
|
func DialAddr(
|
||||||
addr string,
|
addr string,
|
||||||
|
@ -69,7 +70,7 @@ func DialAddr(
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
|
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
|
||||||
// The hostname for SNI is taken from the given address.
|
// See DialAddr for details.
|
||||||
func DialAddrContext(
|
func DialAddrContext(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
addr string,
|
addr string,
|
||||||
|
@ -88,6 +89,8 @@ func DialAddrContext(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||||
|
// The same PacketConn can be used for multiple calls to Dial and Listen,
|
||||||
|
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||||
// The host parameter is used for SNI.
|
// The host parameter is used for SNI.
|
||||||
func Dial(
|
func Dial(
|
||||||
pconn net.PacketConn,
|
pconn net.PacketConn,
|
||||||
|
@ -100,7 +103,7 @@ func Dial(
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
|
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
|
||||||
// The host parameter is used for SNI.
|
// See Dial for details.
|
||||||
func DialContext(
|
func DialContext(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
pconn net.PacketConn,
|
pconn net.PacketConn,
|
||||||
|
@ -164,7 +167,18 @@ func newClient(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srcConnID, err := generateConnectionID(config.ConnectionIDLength)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
destConnID, err := generateConnectionIDForInitial()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
c := &client{
|
c := &client{
|
||||||
|
srcConnID: srcConnID,
|
||||||
|
destConnID: destConnID,
|
||||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||||
createdPacketConn: createdPacketConn,
|
createdPacketConn: createdPacketConn,
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
|
@ -173,7 +187,7 @@ func newClient(
|
||||||
handshakeChan: make(chan struct{}),
|
handshakeChan: make(chan struct{}),
|
||||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||||
}
|
}
|
||||||
return c, c.generateConnectionIDs()
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||||
|
@ -234,20 +248,6 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) generateConnectionIDs() error {
|
|
||||||
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
destConnID, err := generateConnectionIDForInitial()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.srcConnID = srcConnID
|
|
||||||
c.destConnID = destConnID
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) dial(ctx context.Context) error {
|
func (c *client) dial(ctx context.Context) error {
|
||||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||||
|
|
||||||
|
@ -292,65 +292,49 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handlePacket(p *receivedPacket) {
|
func (c *client) handlePacket(p *receivedPacket) {
|
||||||
if err := c.handlePacketImpl(p); err != nil {
|
if p.hdr.IsVersionNegotiation() {
|
||||||
c.logger.Errorf("error handling packet: %s", err)
|
go c.handleVersionNegotiationPacket(p.hdr)
|
||||||
}
|
return
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handlePacketImpl(p *receivedPacket) error {
|
|
||||||
c.mutex.Lock()
|
|
||||||
defer c.mutex.Unlock()
|
|
||||||
|
|
||||||
// handle Version Negotiation Packets
|
|
||||||
if p.header.IsVersionNegotiation {
|
|
||||||
err := c.handleVersionNegotiationPacket(p.header)
|
|
||||||
if err != nil {
|
|
||||||
c.session.destroy(err)
|
|
||||||
}
|
|
||||||
// version negotiation packets have no payload
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// reject packets with the wrong connection ID
|
if p.hdr.Type == protocol.PacketTypeRetry {
|
||||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
go c.handleRetryPacket(p.hdr)
|
||||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
return
|
||||||
}
|
|
||||||
|
|
||||||
if p.header.Type == protocol.PacketTypeRetry {
|
|
||||||
c.handleRetryPacket(p.header)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is the first packet we are receiving
|
// this is the first packet we are receiving
|
||||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||||
if !c.versionNegotiated {
|
if !c.versionNegotiated.Get() {
|
||||||
c.versionNegotiated = true
|
c.versionNegotiated.Set(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.session.handlePacket(p)
|
c.session.handlePacket(p)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
// ignore delayed / duplicated version negotiation packets
|
// ignore delayed / duplicated version negotiation packets
|
||||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
|
||||||
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
|
c.logger.Debugf("Received a delayed Version Negotiation packet.")
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range hdr.SupportedVersions {
|
for _, v := range hdr.SupportedVersions {
|
||||||
if v == c.version {
|
if v == c.version {
|
||||||
// the version negotiation packet contains the version that we offered
|
// The Version Negotiation packet contains the version that we offered.
|
||||||
// this might be a packet sent by an attacker (or by a terribly broken server implementation)
|
// This might be a packet sent by an attacker (or by a terribly broken server implementation).
|
||||||
// ignore it
|
return
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
|
c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||||
if !ok {
|
if !ok {
|
||||||
return qerr.InvalidVersion
|
c.session.destroy(qerr.InvalidVersion)
|
||||||
|
c.logger.Debugf("No compatible version found.")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
c.receivedVersionNegotiationPacket = true
|
c.receivedVersionNegotiationPacket = true
|
||||||
c.negotiatedVersions = hdr.SupportedVersions
|
c.negotiatedVersions = hdr.SupportedVersions
|
||||||
|
@ -358,18 +342,17 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
// switch to negotiated version
|
// switch to negotiated version
|
||||||
c.initialVersion = c.version
|
c.initialVersion = c.version
|
||||||
c.version = newVersion
|
c.version = newVersion
|
||||||
if err := c.generateConnectionIDs(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||||
c.session.destroy(errCloseSessionForNewVersion)
|
c.session.destroy(errCloseSessionForNewVersion)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
c.logger.Debugf("<- Received Retry")
|
c.logger.Debugf("<- Received Retry")
|
||||||
hdr.Log(c.logger)
|
(&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
|
||||||
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
|
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
|
||||||
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
||||||
return
|
return
|
||||||
|
|
7
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
7
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
|
@ -75,12 +75,10 @@ type sentPacketHandler struct {
|
||||||
alarm time.Time
|
alarm time.Time
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
|
|
||||||
version protocol.VersionNumber
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSentPacketHandler creates a new sentPacketHandler
|
// NewSentPacketHandler creates a new sentPacketHandler
|
||||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
|
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
|
||||||
congestion := congestion.NewCubicSender(
|
congestion := congestion.NewCubicSender(
|
||||||
congestion.DefaultClock{},
|
congestion.DefaultClock{},
|
||||||
rttStats,
|
rttStats,
|
||||||
|
@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
congestion: congestion,
|
congestion: congestion,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
version: version,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
|
||||||
|
|
||||||
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||||
pn := h.packetNumberGenerator.Peek()
|
pn := h.packetNumberGenerator.Peek()
|
||||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
|
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||||
|
|
|
@ -11,11 +11,13 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
||||||
c = &tls.Config{}
|
c = &tls.Config{}
|
||||||
}
|
}
|
||||||
// QUIC requires TLS 1.3 or newer
|
// QUIC requires TLS 1.3 or newer
|
||||||
if c.MinVersion < qtls.VersionTLS13 {
|
minVersion := c.MinVersion
|
||||||
c.MinVersion = qtls.VersionTLS13
|
if minVersion < qtls.VersionTLS13 {
|
||||||
|
minVersion = qtls.VersionTLS13
|
||||||
}
|
}
|
||||||
if c.MaxVersion < qtls.VersionTLS13 {
|
maxVersion := c.MaxVersion
|
||||||
c.MaxVersion = qtls.VersionTLS13
|
if maxVersion < qtls.VersionTLS13 {
|
||||||
|
maxVersion = qtls.VersionTLS13
|
||||||
}
|
}
|
||||||
return &qtls.Config{
|
return &qtls.Config{
|
||||||
Rand: c.Rand,
|
Rand: c.Rand,
|
||||||
|
@ -38,8 +40,8 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
||||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||||
SessionTicketKey: c.SessionTicketKey,
|
SessionTicketKey: c.SessionTicketKey,
|
||||||
MinVersion: c.MinVersion,
|
MinVersion: minVersion,
|
||||||
MaxVersion: c.MaxVersion,
|
MaxVersion: maxVersion,
|
||||||
CurvePreferences: c.CurvePreferences,
|
CurvePreferences: c.CurvePreferences,
|
||||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||||
Renegotiation: c.Renegotiation,
|
Renegotiation: c.Renegotiation,
|
||||||
|
|
|
@ -1,20 +1,37 @@
|
||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
|
// PacketNumberLen is the length of the packet number in bytes
|
||||||
|
type PacketNumberLen uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
|
||||||
|
PacketNumberLenInvalid PacketNumberLen = 0
|
||||||
|
// PacketNumberLen1 is a packet number length of 1 byte
|
||||||
|
PacketNumberLen1 PacketNumberLen = 1
|
||||||
|
// PacketNumberLen2 is a packet number length of 2 bytes
|
||||||
|
PacketNumberLen2 PacketNumberLen = 2
|
||||||
|
// PacketNumberLen3 is a packet number length of 3 bytes
|
||||||
|
PacketNumberLen3 PacketNumberLen = 3
|
||||||
|
// PacketNumberLen4 is a packet number length of 4 bytes
|
||||||
|
PacketNumberLen4 PacketNumberLen = 4
|
||||||
|
)
|
||||||
|
|
||||||
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
|
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
|
||||||
func InferPacketNumber(
|
func InferPacketNumber(
|
||||||
packetNumberLength PacketNumberLen,
|
packetNumberLength PacketNumberLen,
|
||||||
lastPacketNumber PacketNumber,
|
lastPacketNumber PacketNumber,
|
||||||
wirePacketNumber PacketNumber,
|
wirePacketNumber PacketNumber,
|
||||||
version VersionNumber,
|
|
||||||
) PacketNumber {
|
) PacketNumber {
|
||||||
var epochDelta PacketNumber
|
var epochDelta PacketNumber
|
||||||
switch packetNumberLength {
|
switch packetNumberLength {
|
||||||
case PacketNumberLen1:
|
case PacketNumberLen1:
|
||||||
epochDelta = PacketNumber(1) << 7
|
epochDelta = PacketNumber(1) << 8
|
||||||
case PacketNumberLen2:
|
case PacketNumberLen2:
|
||||||
epochDelta = PacketNumber(1) << 14
|
epochDelta = PacketNumber(1) << 16
|
||||||
|
case PacketNumberLen3:
|
||||||
|
epochDelta = PacketNumber(1) << 24
|
||||||
case PacketNumberLen4:
|
case PacketNumberLen4:
|
||||||
epochDelta = PacketNumber(1) << 30
|
epochDelta = PacketNumber(1) << 32
|
||||||
}
|
}
|
||||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||||
prevEpochBegin := epoch - epochDelta
|
prevEpochBegin := epoch - epochDelta
|
||||||
|
@ -42,11 +59,14 @@ func delta(a, b PacketNumber) PacketNumber {
|
||||||
|
|
||||||
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
|
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
|
||||||
diff := uint64(packetNumber - leastUnacked)
|
diff := uint64(packetNumber - leastUnacked)
|
||||||
if diff < (1 << (14 - 1)) {
|
if diff < (1 << (16 - 1)) {
|
||||||
return PacketNumberLen2
|
return PacketNumberLen2
|
||||||
}
|
}
|
||||||
|
if diff < (1 << (24 - 1)) {
|
||||||
|
return PacketNumberLen3
|
||||||
|
}
|
||||||
return PacketNumberLen4
|
return PacketNumberLen4
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,5 +78,8 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
|
||||||
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
|
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
|
||||||
return PacketNumberLen2
|
return PacketNumberLen2
|
||||||
}
|
}
|
||||||
|
if packetNumber < (1 << (uint8(PacketNumberLen3) * 8)) {
|
||||||
|
return PacketNumberLen3
|
||||||
|
}
|
||||||
return PacketNumberLen4
|
return PacketNumberLen4
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,32 +7,18 @@ import (
|
||||||
// A PacketNumber in QUIC
|
// A PacketNumber in QUIC
|
||||||
type PacketNumber uint64
|
type PacketNumber uint64
|
||||||
|
|
||||||
// PacketNumberLen is the length of the packet number in bytes
|
|
||||||
type PacketNumberLen uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
|
|
||||||
PacketNumberLenInvalid PacketNumberLen = 0
|
|
||||||
// PacketNumberLen1 is a packet number length of 1 byte
|
|
||||||
PacketNumberLen1 PacketNumberLen = 1
|
|
||||||
// PacketNumberLen2 is a packet number length of 2 bytes
|
|
||||||
PacketNumberLen2 PacketNumberLen = 2
|
|
||||||
// PacketNumberLen4 is a packet number length of 4 bytes
|
|
||||||
PacketNumberLen4 PacketNumberLen = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
// The PacketType is the Long Header Type
|
// The PacketType is the Long Header Type
|
||||||
type PacketType uint8
|
type PacketType uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// PacketTypeInitial is the packet type of an Initial packet
|
// PacketTypeInitial is the packet type of an Initial packet
|
||||||
PacketTypeInitial PacketType = 0x7f
|
PacketTypeInitial PacketType = 1 + iota
|
||||||
// PacketTypeRetry is the packet type of a Retry packet
|
// PacketTypeRetry is the packet type of a Retry packet
|
||||||
PacketTypeRetry PacketType = 0x7e
|
PacketTypeRetry
|
||||||
// PacketTypeHandshake is the packet type of a Handshake packet
|
// PacketTypeHandshake is the packet type of a Handshake packet
|
||||||
PacketTypeHandshake PacketType = 0x7d
|
PacketTypeHandshake
|
||||||
// PacketType0RTT is the packet type of a 0-RTT packet
|
// PacketType0RTT is the packet type of a 0-RTT packet
|
||||||
PacketType0RTT PacketType = 0x7c
|
PacketType0RTT
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t PacketType) String() string {
|
func (t PacketType) String() string {
|
||||||
|
@ -72,11 +58,5 @@ const DefaultTCPMSS ByteCount = 1460
|
||||||
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
|
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
|
||||||
const MinInitialPacketSize = 1200
|
const MinInitialPacketSize = 1200
|
||||||
|
|
||||||
// MaxClientHellos is the maximum number of times we'll send a client hello
|
|
||||||
// The value 3 accounts for:
|
|
||||||
// * one failure due to an incorrect or missing source-address token
|
|
||||||
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
|
|
||||||
const MaxClientHellos = 3
|
|
||||||
|
|
||||||
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
||||||
const MinConnectionIDLenInitial = 8
|
const MinConnectionIDLenInitial = 8
|
||||||
|
|
|
@ -8,11 +8,10 @@ import (
|
||||||
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
|
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
|
||||||
type ByteOrder interface {
|
type ByteOrder interface {
|
||||||
ReadUintN(b io.ByteReader, length uint8) (uint64, error)
|
ReadUintN(b io.ByteReader, length uint8) (uint64, error)
|
||||||
ReadUint64(io.ByteReader) (uint64, error)
|
|
||||||
ReadUint32(io.ByteReader) (uint32, error)
|
ReadUint32(io.ByteReader) (uint32, error)
|
||||||
ReadUint16(io.ByteReader) (uint16, error)
|
ReadUint16(io.ByteReader) (uint16, error)
|
||||||
|
|
||||||
WriteUint64(*bytes.Buffer, uint64)
|
WriteUintN(b *bytes.Buffer, length uint8, value uint64)
|
||||||
WriteUint32(*bytes.Buffer, uint32)
|
WriteUint32(*bytes.Buffer, uint32)
|
||||||
WriteUint16(*bytes.Buffer, uint16)
|
WriteUint16(*bytes.Buffer, uint16)
|
||||||
}
|
}
|
||||||
|
|
41
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
41
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
|
@ -25,37 +25,6 @@ func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadUint64 reads a uint64
|
|
||||||
func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) {
|
|
||||||
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
|
|
||||||
var err error
|
|
||||||
if b8, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b7, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b6, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b5, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b4, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b3, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b2, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b1, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint32 reads a uint32
|
// ReadUint32 reads a uint32
|
||||||
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
|
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
|
||||||
var b1, b2, b3, b4 uint8
|
var b1, b2, b3, b4 uint8
|
||||||
|
@ -88,12 +57,10 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
|
||||||
return uint16(b1) + uint16(b2)<<8, nil
|
return uint16(b1) + uint16(b2)<<8, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteUint64 writes a uint64
|
func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) {
|
||||||
func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
|
for j := length; j > 0; j-- {
|
||||||
b.Write([]byte{
|
b.WriteByte(uint8(i >> (8 * (j - 1))))
|
||||||
uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
|
}
|
||||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteUint32 writes a uint32
|
// WriteUint32 writes a uint32
|
||||||
|
|
205
vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go
generated
vendored
Normal file
205
vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go
generated
vendored
Normal file
|
@ -0,0 +1,205 @@
|
||||||
|
package wire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExtendedHeader is the header of a QUIC packet.
|
||||||
|
type ExtendedHeader struct {
|
||||||
|
Header
|
||||||
|
|
||||||
|
typeByte byte
|
||||||
|
Raw []byte
|
||||||
|
|
||||||
|
PacketNumberLen protocol.PacketNumberLen
|
||||||
|
PacketNumber protocol.PacketNumber
|
||||||
|
|
||||||
|
KeyPhase int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||||
|
// read the (now unencrypted) first byte
|
||||||
|
var err error
|
||||||
|
h.typeByte, err = b.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if h.IsLongHeader {
|
||||||
|
return h.parseLongHeader(b, v)
|
||||||
|
}
|
||||||
|
return h.parseShortHeader(b, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||||
|
if h.typeByte&0xc != 0 {
|
||||||
|
return nil, errors.New("5th and 6th bit must be 0")
|
||||||
|
}
|
||||||
|
if err := h.readPacketNumber(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||||
|
if h.typeByte&0x18 != 0 {
|
||||||
|
return nil, errors.New("4th and 5th bit must be 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.KeyPhase = int(h.typeByte&0x4) >> 2
|
||||||
|
|
||||||
|
if err := h.readPacketNumber(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
|
||||||
|
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
|
||||||
|
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.PacketNumber = protocol.PacketNumber(pn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes the Header.
|
||||||
|
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
|
||||||
|
if h.IsLongHeader {
|
||||||
|
return h.writeLongHeader(b, ver)
|
||||||
|
}
|
||||||
|
return h.writeShortHeader(b, ver)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||||
|
var packetType uint8
|
||||||
|
switch h.Type {
|
||||||
|
case protocol.PacketTypeInitial:
|
||||||
|
packetType = 0x0
|
||||||
|
case protocol.PacketType0RTT:
|
||||||
|
packetType = 0x1
|
||||||
|
case protocol.PacketTypeHandshake:
|
||||||
|
packetType = 0x2
|
||||||
|
case protocol.PacketTypeRetry:
|
||||||
|
packetType = 0x3
|
||||||
|
}
|
||||||
|
firstByte := 0xc0 | packetType<<4
|
||||||
|
if h.Type == protocol.PacketTypeRetry {
|
||||||
|
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
firstByte |= odcil
|
||||||
|
} else { // Retry packets don't have a packet number
|
||||||
|
firstByte |= uint8(h.PacketNumberLen - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteByte(firstByte)
|
||||||
|
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
||||||
|
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.WriteByte(connIDLen)
|
||||||
|
b.Write(h.DestConnectionID.Bytes())
|
||||||
|
b.Write(h.SrcConnectionID.Bytes())
|
||||||
|
|
||||||
|
switch h.Type {
|
||||||
|
case protocol.PacketTypeRetry:
|
||||||
|
b.Write(h.OrigDestConnectionID.Bytes())
|
||||||
|
b.Write(h.Token)
|
||||||
|
return nil
|
||||||
|
case protocol.PacketTypeInitial:
|
||||||
|
utils.WriteVarInt(b, uint64(len(h.Token)))
|
||||||
|
b.Write(h.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
utils.WriteVarInt(b, uint64(h.Length))
|
||||||
|
return h.writePacketNumber(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add support for the key phase
|
||||||
|
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||||
|
typeByte := 0x40 | uint8(h.PacketNumberLen-1)
|
||||||
|
typeByte |= byte(h.KeyPhase << 2)
|
||||||
|
|
||||||
|
b.WriteByte(typeByte)
|
||||||
|
b.Write(h.DestConnectionID.Bytes())
|
||||||
|
return h.writePacketNumber(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
|
||||||
|
if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 {
|
||||||
|
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
|
||||||
|
}
|
||||||
|
utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLength determines the length of the Header.
|
||||||
|
func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
|
||||||
|
if h.IsLongHeader {
|
||||||
|
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
|
||||||
|
if h.Type == protocol.PacketTypeInitial {
|
||||||
|
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
||||||
|
}
|
||||||
|
return length
|
||||||
|
}
|
||||||
|
|
||||||
|
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
||||||
|
length += protocol.ByteCount(h.PacketNumberLen)
|
||||||
|
return length
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log logs the Header
|
||||||
|
func (h *ExtendedHeader) Log(logger utils.Logger) {
|
||||||
|
if h.IsLongHeader {
|
||||||
|
var token string
|
||||||
|
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
|
||||||
|
if len(h.Token) == 0 {
|
||||||
|
token = "Token: (empty), "
|
||||||
|
} else {
|
||||||
|
token = fmt.Sprintf("Token: %#x, ", h.Token)
|
||||||
|
}
|
||||||
|
if h.Type == protocol.PacketTypeRetry {
|
||||||
|
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
|
||||||
|
} else {
|
||||||
|
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
||||||
|
dcil, err := encodeSingleConnIDLen(dest)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
scil, err := encodeSingleConnIDLen(src)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return scil | dcil<<4, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
|
||||||
|
len := id.Len()
|
||||||
|
if len == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if len < 4 || len > 18 {
|
||||||
|
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
|
||||||
|
}
|
||||||
|
return byte(len - 3), nil
|
||||||
|
}
|
|
@ -2,150 +2,183 @@ package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"errors"
|
||||||
"fmt"
|
"io"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Header is the header of a QUIC packet.
|
// The Header is the version independent part of the header
|
||||||
type Header struct {
|
type Header struct {
|
||||||
Raw []byte
|
|
||||||
|
|
||||||
Version protocol.VersionNumber
|
Version protocol.VersionNumber
|
||||||
|
|
||||||
DestConnectionID protocol.ConnectionID
|
|
||||||
SrcConnectionID protocol.ConnectionID
|
SrcConnectionID protocol.ConnectionID
|
||||||
OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
|
DestConnectionID protocol.ConnectionID
|
||||||
|
|
||||||
PacketNumberLen protocol.PacketNumberLen
|
|
||||||
PacketNumber protocol.PacketNumber
|
|
||||||
|
|
||||||
IsVersionNegotiation bool
|
|
||||||
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
|
|
||||||
|
|
||||||
Type protocol.PacketType
|
|
||||||
IsLongHeader bool
|
IsLongHeader bool
|
||||||
KeyPhase int
|
Type protocol.PacketType
|
||||||
Length protocol.ByteCount
|
Length protocol.ByteCount
|
||||||
|
|
||||||
Token []byte
|
Token []byte
|
||||||
|
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
|
||||||
|
OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
|
||||||
|
|
||||||
|
typeByte byte
|
||||||
|
len int // how many bytes were read while parsing this header
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes the Header.
|
// ParseHeader parses the header.
|
||||||
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
|
// For short header packets: up to the packet number.
|
||||||
if h.IsLongHeader {
|
// For long header packets:
|
||||||
return h.writeLongHeader(b, ver)
|
// * if we understand the version: up to the packet number
|
||||||
|
// * if not, only the invariant part of the header
|
||||||
|
func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
|
||||||
|
startLen := b.Len()
|
||||||
|
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return h.writeShortHeader(b, ver)
|
h.len = startLen - b.Len()
|
||||||
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: add support for the key phase
|
func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
|
||||||
func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
typeByte, err := b.ReadByte()
|
||||||
b.WriteByte(byte(0x80 | h.Type))
|
if err != nil {
|
||||||
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
return nil, err
|
||||||
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
|
}
|
||||||
|
|
||||||
|
h := &Header{
|
||||||
|
typeByte: typeByte,
|
||||||
|
IsLongHeader: typeByte&0x80 > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.IsLongHeader {
|
||||||
|
if h.typeByte&0x40 == 0 {
|
||||||
|
return nil, errors.New("not a QUIC packet")
|
||||||
|
}
|
||||||
|
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
if err := h.parseLongHeader(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
|
||||||
|
var err error
|
||||||
|
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
||||||
|
v, err := utils.BigEndian.ReadUint32(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
b.WriteByte(connIDLen)
|
h.Version = protocol.VersionNumber(v)
|
||||||
b.Write(h.DestConnectionID.Bytes())
|
if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 {
|
||||||
b.Write(h.SrcConnectionID.Bytes())
|
return errors.New("not a QUIC packet")
|
||||||
|
|
||||||
if h.Type == protocol.PacketTypeInitial {
|
|
||||||
utils.WriteVarInt(b, uint64(len(h.Token)))
|
|
||||||
b.Write(h.Token)
|
|
||||||
}
|
}
|
||||||
|
connIDLenByte, err := b.ReadByte()
|
||||||
if h.Type == protocol.PacketTypeRetry {
|
|
||||||
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// randomize the first 4 bits
|
dcil, scil := decodeConnIDLen(connIDLenByte)
|
||||||
odcilByte := make([]byte, 1)
|
h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
|
||||||
_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
|
if err != nil {
|
||||||
odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
|
return err
|
||||||
b.Write(odcilByte)
|
}
|
||||||
b.Write(h.OrigDestConnectionID.Bytes())
|
h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
|
||||||
b.Write(h.Token)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if h.Version == 0 {
|
||||||
|
return h.parseVersionNegotiationPacket(b)
|
||||||
|
}
|
||||||
|
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
|
||||||
|
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.WriteVarInt(b, uint64(h.Length))
|
switch (h.typeByte & 0x30) >> 4 {
|
||||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
case 0x0:
|
||||||
}
|
h.Type = protocol.PacketTypeInitial
|
||||||
|
case 0x1:
|
||||||
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
h.Type = protocol.PacketType0RTT
|
||||||
typeByte := byte(0x30)
|
case 0x2:
|
||||||
typeByte |= byte(h.KeyPhase << 6)
|
h.Type = protocol.PacketTypeHandshake
|
||||||
|
case 0x3:
|
||||||
b.WriteByte(typeByte)
|
h.Type = protocol.PacketTypeRetry
|
||||||
b.Write(h.DestConnectionID.Bytes())
|
|
||||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLength determines the length of the Header.
|
|
||||||
func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
|
|
||||||
if h.IsLongHeader {
|
|
||||||
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
|
|
||||||
if h.Type == protocol.PacketTypeInitial {
|
|
||||||
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
|
||||||
}
|
|
||||||
return length
|
|
||||||
}
|
}
|
||||||
|
|
||||||
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
|
||||||
length += protocol.ByteCount(h.PacketNumberLen)
|
|
||||||
return length
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log logs the Header
|
|
||||||
func (h *Header) Log(logger utils.Logger) {
|
|
||||||
if h.IsLongHeader {
|
|
||||||
if h.Version == 0 {
|
|
||||||
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
|
|
||||||
} else {
|
|
||||||
var token string
|
|
||||||
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
|
|
||||||
if len(h.Token) == 0 {
|
|
||||||
token = "Token: (empty), "
|
|
||||||
} else {
|
|
||||||
token = fmt.Sprintf("Token: %#x, ", h.Token)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if h.Type == protocol.PacketTypeRetry {
|
if h.Type == protocol.PacketTypeRetry {
|
||||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
|
odcil := decodeSingleConnIDLen(h.typeByte & 0xf)
|
||||||
return
|
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
|
h.Token = make([]byte, b.Len())
|
||||||
|
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
return nil
|
||||||
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.Type == protocol.PacketTypeInitial {
|
||||||
|
tokenLen, err := utils.ReadVarInt(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if tokenLen > uint64(b.Len()) {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
h.Token = make([]byte, tokenLen)
|
||||||
|
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pl, err := utils.ReadVarInt(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.Length = protocol.ByteCount(pl)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error {
|
||||||
dcil, err := encodeSingleConnIDLen(dest)
|
if b.Len() == 0 {
|
||||||
if err != nil {
|
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
||||||
return 0, err
|
|
||||||
}
|
}
|
||||||
scil, err := encodeSingleConnIDLen(src)
|
h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
|
||||||
|
for i := 0; b.Len() > 0; i++ {
|
||||||
|
v, err := utils.BigEndian.ReadUint32(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return qerr.InvalidVersionNegotiationPacket
|
||||||
}
|
}
|
||||||
return scil | dcil<<4, nil
|
h.SupportedVersions[i] = protocol.VersionNumber(v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
|
// IsVersionNegotiation says if this a version negotiation packet
|
||||||
len := id.Len()
|
func (h *Header) IsVersionNegotiation() bool {
|
||||||
if len == 0 {
|
return h.IsLongHeader && h.Version == 0
|
||||||
return 0, nil
|
}
|
||||||
}
|
|
||||||
if len < 4 || len > 18 {
|
// ParseExtended parses the version dependent part of the header.
|
||||||
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
|
// The Reader has to be set such that it points to the first byte of the header.
|
||||||
}
|
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
|
||||||
return byte(len - 3), nil
|
return h.toExtendedHeader().parse(b, ver)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Header) toExtendedHeader() *ExtendedHeader {
|
||||||
|
return &ExtendedHeader{Header: *h}
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
|
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
|
||||||
|
|
|
@ -162,75 +162,46 @@ func (h *packetHandlerMap) listen() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
|
||||||
rcvTime := time.Now()
|
|
||||||
|
|
||||||
r := bytes.NewReader(data)
|
r := bytes.NewReader(data)
|
||||||
iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
|
hdr, err := wire.ParseHeader(r, h.connIDLen)
|
||||||
// drop the packet if we can't parse the header
|
// drop the packet if we can't parse the header
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error parsing invariant header: %s", err)
|
return fmt.Errorf("error parsing header: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &receivedPacket{
|
||||||
|
remoteAddr: addr,
|
||||||
|
hdr: hdr,
|
||||||
|
data: data,
|
||||||
|
rcvTime: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)]
|
defer h.mutex.RUnlock()
|
||||||
server := h.server
|
|
||||||
|
handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
|
||||||
|
|
||||||
var sentBy protocol.Perspective
|
|
||||||
var version protocol.VersionNumber
|
|
||||||
var handlePacket func(*receivedPacket)
|
|
||||||
if handlerFound { // existing session
|
if handlerFound { // existing session
|
||||||
handler := handlerEntry.handler
|
handlerEntry.handler.handlePacket(p)
|
||||||
sentBy = handler.GetPerspective().Opposite()
|
return nil
|
||||||
version = handler.GetVersion()
|
}
|
||||||
handlePacket = handler.handlePacket
|
// No session found.
|
||||||
} else { // no session found
|
// This might be a stateless reset.
|
||||||
// this might be a stateless reset
|
if !hdr.IsLongHeader {
|
||||||
if !iHdr.IsLongHeader {
|
|
||||||
if len(data) >= protocol.MinStatelessResetSize {
|
if len(data) >= protocol.MinStatelessResetSize {
|
||||||
var token [16]byte
|
var token [16]byte
|
||||||
copy(token[:], data[len(data)-16:])
|
copy(token[:], data[len(data)-16:])
|
||||||
if sess, ok := h.resetTokens[token]; ok {
|
if sess, ok := h.resetTokens[token]; ok {
|
||||||
h.mutex.RUnlock()
|
|
||||||
sess.destroy(errors.New("received a stateless reset"))
|
sess.destroy(errors.New("received a stateless reset"))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO(#943): send a stateless reset
|
// TODO(#943): send a stateless reset
|
||||||
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
|
||||||
}
|
}
|
||||||
if server == nil { // no server set
|
if h.server == nil { // no server set
|
||||||
h.mutex.RUnlock()
|
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
|
||||||
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
|
||||||
}
|
}
|
||||||
handlePacket = server.handlePacket
|
h.server.handlePacket(p)
|
||||||
sentBy = protocol.PerspectiveClient
|
|
||||||
version = iHdr.Version
|
|
||||||
}
|
|
||||||
h.mutex.RUnlock()
|
|
||||||
|
|
||||||
hdr, err := iHdr.Parse(r, sentBy, version)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error parsing header: %s", err)
|
|
||||||
}
|
|
||||||
hdr.Raw = data[:len(data)-r.Len()]
|
|
||||||
packetData := data[len(data)-r.Len():]
|
|
||||||
|
|
||||||
if hdr.IsLongHeader {
|
|
||||||
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
|
|
||||||
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
|
|
||||||
}
|
|
||||||
if protocol.ByteCount(len(packetData))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
|
|
||||||
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(packetData)+int(hdr.PacketNumberLen), hdr.Length)
|
|
||||||
}
|
|
||||||
packetData = packetData[:int(hdr.Length)-int(hdr.PacketNumberLen)]
|
|
||||||
// TODO(#1312): implement parsing of compound packets
|
|
||||||
}
|
|
||||||
|
|
||||||
handlePacket(&receivedPacket{
|
|
||||||
remoteAddr: addr,
|
|
||||||
header: hdr,
|
|
||||||
data: packetData,
|
|
||||||
rcvTime: rcvTime,
|
|
||||||
})
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ type packer interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type packedPacket struct {
|
type packedPacket struct {
|
||||||
header *wire.Header
|
header *wire.ExtendedHeader
|
||||||
raw []byte
|
raw []byte
|
||||||
frames []wire.Frame
|
frames []wire.Frame
|
||||||
encryptionLevel protocol.EncryptionLevel
|
encryptionLevel protocol.EncryptionLevel
|
||||||
|
@ -397,14 +397,13 @@ func (p *packetPacker) composeNextPacket(
|
||||||
return frames, nil
|
return frames, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
|
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
|
||||||
pn, pnLen := p.pnManager.PeekPacketNumber()
|
pn, pnLen := p.pnManager.PeekPacketNumber()
|
||||||
header := &wire.Header{
|
header := &wire.ExtendedHeader{}
|
||||||
PacketNumber: pn,
|
header.PacketNumber = pn
|
||||||
PacketNumberLen: pnLen,
|
header.PacketNumberLen = pnLen
|
||||||
Version: p.version,
|
header.Version = p.version
|
||||||
DestConnectionID: p.destConnID,
|
header.DestConnectionID = p.destConnID
|
||||||
}
|
|
||||||
|
|
||||||
if encLevel != protocol.Encryption1RTT {
|
if encLevel != protocol.Encryption1RTT {
|
||||||
header.IsLongHeader = true
|
header.IsLongHeader = true
|
||||||
|
@ -424,8 +423,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetPacker) writeAndSealPacket(
|
func (p *packetPacker) writeAndSealPacket(
|
||||||
header *wire.Header,
|
header *wire.ExtendedHeader, frames []wire.Frame,
|
||||||
frames []wire.Frame,
|
|
||||||
sealer handshake.Sealer,
|
sealer handshake.Sealer,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
raw := *getPacketBuffer()
|
raw := *getPacketBuffer()
|
||||||
|
@ -450,7 +448,7 @@ func (p *packetPacker) writeAndSealPacket(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := header.Write(buffer, p.perspective, p.version); err != nil {
|
if err := header.Write(buffer, p.version); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
payloadStartIndex := buffer.Len()
|
payloadStartIndex := buffer.Len()
|
||||||
|
|
|
@ -35,7 +35,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) {
|
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
|
||||||
buf := *getPacketBuffer()
|
buf := *getPacketBuffer()
|
||||||
buf = buf[:0]
|
buf = buf[:0]
|
||||||
defer putPacketBuffer(&buf)
|
defer putPacketBuffer(&buf)
|
||||||
|
|
|
@ -21,7 +21,6 @@ type packetHandler interface {
|
||||||
handlePacket(*receivedPacket)
|
handlePacket(*receivedPacket)
|
||||||
io.Closer
|
io.Closer
|
||||||
destroy(error)
|
destroy(error)
|
||||||
GetVersion() protocol.VersionNumber
|
|
||||||
GetPerspective() protocol.Perspective
|
GetPerspective() protocol.Perspective
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,7 +98,8 @@ var _ Listener = &server{}
|
||||||
var _ unknownPacketHandler = &server{}
|
var _ unknownPacketHandler = &server{}
|
||||||
|
|
||||||
// ListenAddr creates a QUIC server listening on a given address.
|
// ListenAddr creates a QUIC server listening on a given address.
|
||||||
// The tls.Config must not be nil, the quic.Config may be nil.
|
// The tls.Config must not be nil and must contain a certificate configuration.
|
||||||
|
// The quic.Config may be nil, in that case the default values will be used.
|
||||||
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
|
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -118,7 +118,11 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen listens for QUIC connections on a given net.PacketConn.
|
// Listen listens for QUIC connections on a given net.PacketConn.
|
||||||
// The tls.Config must not be nil, the quic.Config may be nil.
|
// A single PacketConn only be used for a single call to Listen.
|
||||||
|
// The PacketConn can be used for simultaneous calls to Dial.
|
||||||
|
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||||
|
// The tls.Config must not be nil and must contain a certificate configuration.
|
||||||
|
// The quic.Config may be nil, in that case the default values will be used.
|
||||||
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
|
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
|
||||||
return listen(conn, tlsConf, config)
|
return listen(conn, tlsConf, config)
|
||||||
}
|
}
|
||||||
|
@ -300,23 +304,17 @@ func (s *server) Addr() net.Addr {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) handlePacket(p *receivedPacket) {
|
func (s *server) handlePacket(p *receivedPacket) {
|
||||||
if err := s.handlePacketImpl(p); err != nil {
|
hdr := p.hdr
|
||||||
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) handlePacketImpl(p *receivedPacket) error {
|
|
||||||
hdr := p.header
|
|
||||||
|
|
||||||
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
// send a Version Negotiation Packet if the client is speaking a different protocol version
|
||||||
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
||||||
return s.sendVersionNegotiationPacket(p)
|
go s.sendVersionNegotiationPacket(p)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if hdr.Type == protocol.PacketTypeInitial {
|
if hdr.Type == protocol.PacketTypeInitial {
|
||||||
go s.handleInitial(p)
|
go s.handleInitial(p)
|
||||||
}
|
}
|
||||||
// TODO(#943): send Stateless Reset
|
// TODO(#943): send Stateless Reset
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) handleInitial(p *receivedPacket) {
|
func (s *server) handleInitial(p *receivedPacket) {
|
||||||
|
@ -335,11 +333,11 @@ func (s *server) handleInitial(p *receivedPacket) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
|
func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
|
||||||
hdr := p.header
|
hdr := p.hdr
|
||||||
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||||
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
|
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
|
||||||
}
|
}
|
||||||
if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize {
|
if len(p.data) < protocol.MinInitialPacketSize {
|
||||||
return nil, nil, errors.New("dropping too small Initial packet")
|
return nil, nil, errors.New("dropping too small Initial packet")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -358,7 +356,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con
|
||||||
if !s.config.AcceptCookie(p.remoteAddr, cookie) {
|
if !s.config.AcceptCookie(p.remoteAddr, cookie) {
|
||||||
// Log the Initial packet now.
|
// Log the Initial packet now.
|
||||||
// If no Retry is sent, the packet will be logged by the session.
|
// If no Retry is sent, the packet will be logged by the session.
|
||||||
p.header.Log(s.logger)
|
(&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger)
|
||||||
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
|
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -431,19 +429,18 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
replyHdr := &wire.Header{
|
replyHdr := &wire.ExtendedHeader{}
|
||||||
IsLongHeader: true,
|
replyHdr.IsLongHeader = true
|
||||||
Type: protocol.PacketTypeRetry,
|
replyHdr.Type = protocol.PacketTypeRetry
|
||||||
Version: hdr.Version,
|
replyHdr.Version = hdr.Version
|
||||||
SrcConnectionID: connID,
|
replyHdr.SrcConnectionID = connID
|
||||||
DestConnectionID: hdr.SrcConnectionID,
|
replyHdr.DestConnectionID = hdr.SrcConnectionID
|
||||||
OrigDestConnectionID: hdr.DestConnectionID,
|
replyHdr.OrigDestConnectionID = hdr.DestConnectionID
|
||||||
Token: token,
|
replyHdr.Token = token
|
||||||
}
|
|
||||||
s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
|
s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
|
||||||
replyHdr.Log(s.logger)
|
replyHdr.Log(s.logger)
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil {
|
if err := replyHdr.Write(buf, hdr.Version); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
|
if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
|
||||||
|
@ -452,14 +449,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
|
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
|
||||||
hdr := p.header
|
hdr := p.hdr
|
||||||
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
|
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
|
||||||
|
|
||||||
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
s.logger.Debugf("Error composing Version Negotiation: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
|
||||||
|
s.logger.Debugf("Error sending Version Negotiation: %s", err)
|
||||||
}
|
}
|
||||||
_, err = s.conn.WriteTo(data, p.remoteAddr)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ func (s *serverSession) handlePacket(p *receivedPacket) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
|
func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
|
||||||
hdr := p.header
|
hdr := p.hdr
|
||||||
|
|
||||||
// Probably an old packet that was sent by the client before the version was negotiated.
|
// Probably an old packet that was sent by the client before the version was negotiated.
|
||||||
// It is safe to drop it.
|
// It is safe to drop it.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -21,7 +22,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type unpacker interface {
|
type unpacker interface {
|
||||||
Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error)
|
Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamGetter interface {
|
type streamGetter interface {
|
||||||
|
@ -52,7 +53,7 @@ type cryptoStreamHandler interface {
|
||||||
|
|
||||||
type receivedPacket struct {
|
type receivedPacket struct {
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
header *wire.Header
|
hdr *wire.Header
|
||||||
data []byte
|
data []byte
|
||||||
rcvTime time.Time
|
rcvTime time.Time
|
||||||
}
|
}
|
||||||
|
@ -113,7 +114,6 @@ type session struct {
|
||||||
|
|
||||||
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
|
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
|
||||||
receivedFirstForwardSecurePacket bool
|
receivedFirstForwardSecurePacket bool
|
||||||
lastRcvdPacketNumber protocol.PacketNumber
|
|
||||||
// Used to calculate the next packet number from the truncated wire
|
// Used to calculate the next packet number from the truncated wire
|
||||||
// representation, and sent back in public reset packets
|
// representation, and sent back in public reset packets
|
||||||
largestRcvdPacketNumber protocol.PacketNumber
|
largestRcvdPacketNumber protocol.PacketNumber
|
||||||
|
@ -289,7 +289,7 @@ var newClientSession = func(
|
||||||
|
|
||||||
func (s *session) preSetup() {
|
func (s *session) preSetup() {
|
||||||
s.rttStats = &congestion.RTTStats{}
|
s.rttStats = &congestion.RTTStats{}
|
||||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version)
|
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
|
||||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
|
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
|
||||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||||
protocol.InitialMaxData,
|
protocol.InitialMaxData,
|
||||||
|
@ -374,7 +374,7 @@ runLoop:
|
||||||
}
|
}
|
||||||
// This is a bit unclean, but works properly, since the packet always
|
// This is a bit unclean, but works properly, since the packet always
|
||||||
// begins with the public header and we never copy it.
|
// begins with the public header and we never copy it.
|
||||||
putPacketBuffer(&p.header.Raw)
|
// TODO: putPacketBuffer(&p.extHdr.Raw)
|
||||||
case <-s.handshakeCompleteChan:
|
case <-s.handshakeCompleteChan:
|
||||||
s.handleHandshakeComplete()
|
s.handleHandshakeComplete()
|
||||||
}
|
}
|
||||||
|
@ -479,24 +479,41 @@ func (s *session) handleHandshakeComplete() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handlePacketImpl(p *receivedPacket) error {
|
func (s *session) handlePacketImpl(p *receivedPacket) error {
|
||||||
hdr := p.header
|
|
||||||
// The server can change the source connection ID with the first Handshake packet.
|
// The server can change the source connection ID with the first Handshake packet.
|
||||||
// After this, all packets with a different source connection have to be ignored.
|
// After this, all packets with a different source connection have to be ignored.
|
||||||
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
|
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||||
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID)
|
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.rcvTime = time.Now()
|
data := p.data
|
||||||
|
r := bytes.NewReader(data)
|
||||||
|
hdr, err := p.hdr.ParseExtended(r, s.version)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error parsing extended header: %s", err)
|
||||||
|
}
|
||||||
|
hdr.Raw = data[:len(data)-r.Len()]
|
||||||
|
data = data[len(data)-r.Len():]
|
||||||
|
|
||||||
|
if hdr.IsLongHeader {
|
||||||
|
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
|
||||||
|
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
|
||||||
|
}
|
||||||
|
if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
|
||||||
|
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
|
||||||
|
}
|
||||||
|
data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
|
||||||
|
// TODO(#1312): implement parsing of compound packets
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate packet number
|
// Calculate packet number
|
||||||
hdr.PacketNumber = protocol.InferPacketNumber(
|
hdr.PacketNumber = protocol.InferPacketNumber(
|
||||||
hdr.PacketNumberLen,
|
hdr.PacketNumberLen,
|
||||||
s.largestRcvdPacketNumber,
|
s.largestRcvdPacketNumber,
|
||||||
hdr.PacketNumber,
|
hdr.PacketNumber,
|
||||||
s.version,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data)
|
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
|
||||||
if s.logger.Debug() {
|
if s.logger.Debug() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID)
|
s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID)
|
||||||
|
@ -530,7 +547,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.lastRcvdPacketNumber = hdr.PacketNumber
|
|
||||||
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
|
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
|
||||||
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
|
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
|
||||||
|
|
||||||
|
@ -543,10 +559,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.handleFrames(packet.frames, packet.encryptionLevel)
|
return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error {
|
func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
|
||||||
for _, ff := range fs {
|
for _, ff := range fs {
|
||||||
var err error
|
var err error
|
||||||
wire.LogFrame(s.logger, ff, false)
|
wire.LogFrame(s.logger, ff, false)
|
||||||
|
@ -556,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
|
||||||
case *wire.StreamFrame:
|
case *wire.StreamFrame:
|
||||||
err = s.handleStreamFrame(frame, encLevel)
|
err = s.handleStreamFrame(frame, encLevel)
|
||||||
case *wire.AckFrame:
|
case *wire.AckFrame:
|
||||||
err = s.handleAckFrame(frame, encLevel)
|
err = s.handleAckFrame(frame, pn, encLevel)
|
||||||
case *wire.ConnectionCloseFrame:
|
case *wire.ConnectionCloseFrame:
|
||||||
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
|
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
|
||||||
case *wire.ResetStreamFrame:
|
case *wire.ResetStreamFrame:
|
||||||
|
@ -702,8 +718,8 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) {
|
||||||
s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data})
|
s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
|
func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
|
||||||
if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
|
if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
|
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
|
||||||
|
@ -1065,14 +1081,14 @@ func (s *session) scheduleSending() {
|
||||||
|
|
||||||
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
|
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
|
||||||
if s.handshakeComplete {
|
if s.handshakeComplete {
|
||||||
s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data))
|
s.logger.Debugf("Received undecryptable packet from %s after the handshake (%d bytes)", p.remoteAddr.String(), len(p.data))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
|
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
|
||||||
s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber)
|
s.logger.Infof("Dropping undecrytable packet (%d bytes). Undecryptable packet queue full.", len(p.data))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber)
|
s.logger.Infof("Queueing packet (%d bytes) for later decryption", len(p.data))
|
||||||
s.undecryptablePackets = append(s.undecryptablePackets, p)
|
s.undecryptablePackets = append(s.undecryptablePackets, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue