update quic

pull/1435/head v4.6.2
Darien Raymond 2018-11-23 17:04:53 +01:00
parent 6870ead73e
commit 19926b8e4f
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
214 changed files with 14826 additions and 33556 deletions

View File

@ -5,37 +5,85 @@ import (
"sync"
"time"
"v2ray.com/core/transport/internet/tls"
quic "github.com/lucas-clemente/quic-go"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/tls"
)
type clientSessions struct {
access sync.Mutex
sessions map[net.Destination]quic.Session
sessions map[net.Destination][]quic.Session
}
func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (quic.Session, error) {
func removeInactiveSessions(sessions []quic.Session) []quic.Session {
lastActive := 0
for _, s := range sessions {
active := true
select {
case <-s.Context().Done():
active = false
default:
}
if active {
sessions[lastActive] = s
lastActive++
}
}
if lastActive < len(sessions) {
for i := 0; i < len(sessions); i++ {
sessions[i] = nil
}
sessions = sessions[:lastActive]
}
return sessions
}
func openStream(sessions []quic.Session) (quic.Stream, net.Addr, error) {
for _, s := range sessions {
stream, err := s.OpenStream()
if err != nil {
newError("failed to create stream").Base(err).WriteToLog()
continue
}
return stream, s.LocalAddr(), nil
}
return nil, nil, nil
}
func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
s.access.Lock()
defer s.access.Unlock()
if s.sessions == nil {
s.sessions = make(map[net.Destination]quic.Session)
s.sessions = make(map[net.Destination][]quic.Session)
}
dest := net.DestinationFromAddr(destAddr)
if session, found := s.sessions[dest]; found {
select {
case <-session.Context().Done():
// Session has been closed. Creating a new one.
default:
return session, nil
}
var sessions []quic.Session
if s, found := s.sessions[dest]; found {
sessions = s
}
sessions = removeInactiveSessions(sessions)
s.sessions[dest] = sessions
stream, local, err := openStream(sessions)
if err != nil {
return nil, err
}
if stream != nil {
return &interConn{
stream: stream,
local: local,
remote: destAddr,
}, nil
}
rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
@ -47,13 +95,12 @@ func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig
}
quicConfig := &quic.Config{
Versions: []quic.VersionNumber{quic.VersionMilestone0_10_0},
ConnectionIDLength: 12,
KeepAlive: true,
HandshakeTimeout: time.Second * 4,
IdleTimeout: time.Second * 300,
MaxReceiveStreamFlowControlWindow: 128 * 1024,
MaxReceiveConnectionFlowControlWindow: 512 * 1024,
IdleTimeout: time.Second * 60,
MaxReceiveStreamFlowControlWindow: 256 * 1024,
MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
MaxIncomingUniStreams: -1,
}
@ -69,8 +116,16 @@ func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig
return nil, err
}
s.sessions[dest] = session
return session, nil
s.sessions[dest] = append(sessions, session)
stream, err = session.OpenStream()
if err != nil {
return nil, err
}
return &interConn{
stream: stream,
local: session.LocalAddr(),
remote: destAddr,
}, nil
}
var client clientSessions
@ -91,21 +146,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
config := streamSettings.ProtocolSettings.(*Config)
session, err := client.getSession(destAddr, config, tlsConfig, streamSettings.SocketSettings)
if err != nil {
return nil, err
}
conn, err := session.OpenStreamSync()
if err != nil {
return nil, err
}
return &interConn{
stream: conn,
local: session.LocalAddr(),
remote: destAddr,
}, nil
return client.openConnection(destAddr, config, tlsConfig, streamSettings.SocketSettings)
}
func init() {

View File

@ -84,14 +84,13 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
}
quicConfig := &quic.Config{
Versions: []quic.VersionNumber{quic.VersionMilestone0_10_0},
ConnectionIDLength: 12,
KeepAlive: true,
HandshakeTimeout: time.Second * 4,
IdleTimeout: time.Second * 300,
MaxReceiveStreamFlowControlWindow: 128 * 1024,
MaxReceiveConnectionFlowControlWindow: 512 * 1024,
MaxIncomingStreams: 256,
IdleTimeout: time.Second * 60,
MaxReceiveStreamFlowControlWindow: 256 * 1024,
MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
MaxIncomingStreams: 64,
MaxIncomingUniStreams: -1,
}

View File

@ -8,12 +8,19 @@
[![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/)
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go. It roughly implements the [IETF QUIC draft](https://github.com/quicwg/base-drafts), although we don't fully support any of the draft versions at the moment.
## Roadmap
## Version compatibility
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
Since quic-go is under active development, there's no guarantee that two builds of different commits are interoperable. The QUIC version used in the *master* branch is just a placeholder, and should not be considered stable.
If you want to use quic-go as a library in other projects, please consider using a [tagged release](https://github.com/lucas-clemente/quic-go/releases). These releases expose [experimental QUIC versions](https://github.com/quicwg/base-drafts/wiki/QUIC-Versions), which are guaranteed to be stable.
## Google QUIC
quic-go used to support both the QUIC versions supported by Google Chrome and QUIC as deployed on Google's servers, as well as IETF QUIC. Due to the divergence of the two protocols, we decided to not support both versions any more.
The *master* branch **only** supports IETF QUIC. For Google QUIC support, please refer to the [gquic branch](https://github.com/lucas-clemente/quic-go/tree/gquic).
## Guides
@ -27,31 +34,19 @@ Running tests:
go test ./...
### Running the example server
### HTTP mapping
go run example/main.go -www /var/www/
Using the `quic_client` from chromium:
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
Using Chrome:
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
We're currently not implementing the HTTP mapping as described in the [QUIC over HTTP draft](https://quicwg.org/base-drafts/draft-ietf-quic-http.html). The HTTP mapping here is a leftover from Google QUIC.
### QUIC without HTTP/2
Take a look at [this echo example](example/echo/echo.go).
### Using the example client
go run example/client/main.go https://clemente.io
## Usage
### As a server
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
```go
http.Handle("/", http.FileServer(http.Dir(wwwDir)))

View File

@ -10,6 +10,9 @@ environment:
- GOARCH: 386
- GOARCH: amd64
hosts:
quic.clemente.io: 127.0.0.1
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install:
@ -19,7 +22,6 @@ install:
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
- echo %PATH%
- echo %GOPATH%
- git submodule update --init --recursive
- go get github.com/onsi/ginkgo/ginkgo
- go get github.com/onsi/gomega
- go version

View File

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
@ -9,12 +8,11 @@ import (
"net"
"sync"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type client struct {
@ -25,29 +23,25 @@ type client struct {
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
hostname string
packetHandlers packetHandlerManager
token []byte
numRetries int
token []byte
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
mintConf *mint.Config
config *Config
tlsConf *tls.Config
config *Config
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
initialVersion protocol.VersionNumber
version protocol.VersionNumber
handshakeChan chan struct{}
closeCallback func(protocol.ConnectionID)
session quicSession
@ -128,18 +122,11 @@ func dialContext(
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
if !createdPacketConn {
for _, v := range config.Versions {
if v == protocol.Version44 {
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
}
}
}
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, createdPacketConn)
if err != nil {
return nil, err
}
@ -156,16 +143,14 @@ func newClient(
config *Config,
tlsConf *tls.Config,
host string,
closeCallback func(protocol.ConnectionID),
createdPacketConn bool,
) (*client, error) {
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
if tlsConf == nil {
tlsConf = &tls.Config{}
}
if hostname == "" {
if tlsConf.ServerName == "" {
var err error
hostname, _, err = net.SplitHostPort(host)
tlsConf.ServerName, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
@ -179,19 +164,13 @@ func newClient(
}
}
}
onClose := func(protocol.ConnectionID) {}
if closeCallback != nil {
onClose = closeCallback
}
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
hostname: hostname,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, c.generateConnectionIDs()
@ -219,11 +198,11 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
}
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
@ -241,17 +220,11 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
if connIDLen == 0 && !createdPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
for _, v := range versions {
if v == protocol.Version44 {
connIDLen = 0
}
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
ConnectionIDLength: connIDLen,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
@ -262,75 +235,26 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
}
func (c *client) generateConnectionIDs() error {
connIDLen := protocol.ConnectionIDLenGQUIC
if c.version.UsesTLS() {
connIDLen = c.config.ConnectionIDLength
}
srcConnID, err := generateConnectionID(connIDLen)
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
if err != nil {
return err
}
destConnID := srcConnID
if c.version.UsesTLS() {
destConnID, err = generateConnectionIDForInitial()
if err != nil {
return err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return err
}
c.srcConnID = srcConnID
c.destConnID = destConnID
if c.version == protocol.Version44 {
c.srcConnID = nil
}
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
var err error
if c.version.UsesTLS() {
err = c.dialTLS(ctx)
} else {
err = c.dialGQUIC(ctx)
}
return err
}
func (c *client) dialGQUIC(ctx context.Context) error {
if err := c.createNewGQUICSession(); err != nil {
if err := c.createNewTLSSession(c.version); err != nil {
return err
}
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return err
}
func (c *client) dialTLS(ctx context.Context) error {
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
DisableMigration: true,
}
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ExtensionHandler = extHandler
mintConf.ServerName = c.hostname
c.mintConf = mintConf
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
err = c.establishSecureConnection(ctx)
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
@ -340,9 +264,9 @@ func (c *client) dialTLS(ctx context.Context) error {
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
// - when the connection is forward-secure
func (c *client) establishSecureConnection(ctx context.Context) error {
errorChan := make(chan error, 1)
@ -387,35 +311,14 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
return err
}
if !c.version.UsesIETFHeaderFormat() {
connID := p.header.DestConnectionID
// reject packets with truncated connection id if we didn't request truncation
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
// reject packets with the wrong connection ID
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
}
if p.header.ResetFlag {
return c.handlePublicReset(p)
}
} else {
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
if p.header.IsLongHeader {
switch p.header.Type {
case protocol.PacketTypeRetry:
c.handleRetryPacket(p.header)
return nil
case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
default:
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
}
if p.header.Type == protocol.PacketTypeRetry {
c.handleRetryPacket(p.header)
return nil
}
// this is the first packet we are receiving
@ -428,22 +331,6 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
return nil
}
func (c *client) handlePublicReset(p *receivedPacket) error {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
return errors.New("Received a spoofed Public Reset")
}
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
if err != nil {
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
}
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
return nil
}
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
@ -483,77 +370,56 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
func (c *client) handleRetryPacket(hdr *wire.Header) {
c.logger.Debugf("<- Received Retry")
hdr.Log(c.logger)
// A server that performs multiple retries must use a source connection ID of at least 8 bytes.
// Only a server that won't send additional Retries can use shorter connection IDs.
if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial)
return
}
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
return
}
c.numRetries++
if c.numRetries > protocol.MaxRetries {
c.session.destroy(qerr.CryptoTooManyRejects)
if hdr.SrcConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
return
}
// If a token is already set, this means that we already received a Retry from the server.
// Ignore this Retry packet.
if len(c.token) > 0 {
c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
return
}
c.origDestConnID = c.destConnID
c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token
c.session.destroy(errCloseSessionForRetry)
}
func (c *client) createNewGQUICSession() error {
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
params := &handshake.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
InitialMaxStreamDataUni: protocol.InitialMaxStreamData,
InitialMaxData: protocol.InitialMaxData,
IdleTimeout: c.config.IdleTimeout,
MaxBidiStreams: uint64(c.config.MaxIncomingStreams),
MaxUniStreams: uint64(c.config.MaxIncomingUniStreams),
DisableMigration: true,
}
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
retireConnectionIDImpl: c.packetHandlers.Retire,
removeConnectionIDImpl: c.packetHandlers.Remove,
}
sess, err := newClientSession(
c.conn,
runner,
c.hostname,
c.version,
c.destConnID,
c.srcConnID,
c.tlsConf,
c.config,
c.initialVersion,
c.negotiatedVersions,
c.logger,
)
if err != nil {
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
if c.config.RequestConnectionIDOmission {
c.packetHandlers.Add(protocol.ConnectionID{}, c)
}
return nil
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) error {
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newTLSClientSession(
c.conn,
runner,
c.token,
c.origDestConnID,
c.destConnID,
c.srcConnID,
c.config,
c.mintConf,
paramsChan,
1,
c.tlsConf,
params,
c.initialVersion,
c.logger,
c.version,
)

View File

@ -1,41 +1,108 @@
package quic
import (
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStream interface {
StreamID() protocol.StreamID
io.Reader
// for receiving data
HandleCryptoFrame(*wire.CryptoFrame) error
GetCryptoData() []byte
Finish() error
// for sending data
io.Writer
handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
HasData() bool
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
}
type cryptoStreamImpl struct {
*stream
queue *frameSorter
msgBuf []byte
highestOffset protocol.ByteCount
finished bool
writeOffset protocol.ByteCount
writeBuf []byte
}
var _ cryptoStream = &cryptoStreamImpl{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStreamImpl{str}
func newCryptoStream() cryptoStream {
return &cryptoStreamImpl{
queue: newFrameSorter(),
}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPos = offset
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)
}
if s.finished {
if highestOffset > s.highestOffset {
// reject crypto data received after this stream was already finished
return errors.New("received crypto data after change of encryption level")
}
// ignore data with a smaller offset than the highest received
// could e.g. be a retransmission
return nil
}
s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
if err := s.queue.Push(f.Data, f.Offset, false); err != nil {
return err
}
for {
data, _ := s.queue.Pop()
if data == nil {
return nil
}
s.msgBuf = append(s.msgBuf, data...)
}
}
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
if len(s.msgBuf) < 4 {
return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
}
func (s *cryptoStreamImpl) Finish() error {
if s.queue.HasMoreData() {
return errors.New("encryption level changed, but crypto stream has more data to read")
}
s.finished = true
return nil
}
// Writes writes data that should be sent out in CRYPTO frames
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...)
return len(p), nil
}
func (s *cryptoStreamImpl) HasData() bool {
return len(s.writeBuf) > 0
}
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
f := &wire.CryptoFrame{Offset: s.writeOffset}
n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
f.Data = s.writeBuf[:n]
s.writeBuf = s.writeBuf[n:]
s.writeOffset += n
return f
}

View File

@ -0,0 +1,55 @@
package quic
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) bool
}
type cryptoStreamManager struct {
cryptoHandler cryptoDataHandler
initialStream cryptoStream
handshakeStream cryptoStream
}
func newCryptoStreamManager(
cryptoHandler cryptoDataHandler,
initialStream cryptoStream,
handshakeStream cryptoStream,
) *cryptoStreamManager {
return &cryptoStreamManager{
cryptoHandler: cryptoHandler,
initialStream: initialStream,
handshakeStream: handshakeStream,
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
var str cryptoStream
switch encLevel {
case protocol.EncryptionInitial:
str = m.initialStream
case protocol.EncryptionHandshake:
str = m.handshakeStream
default:
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return false, err
}
for {
data := str.GetCryptoData()
if data == nil {
return false, nil
}
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
return true, str.Finish()
}
}
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

View File

@ -1,65 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg width="601px" height="242px" viewBox="0 0 601 242" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<!-- Generator: Sketch 3.7.2 (28276) - http://www.bohemiancoding.com/sketch -->
<title>quic</title>
<desc>Created with Sketch.</desc>
<defs></defs>
<g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="Group-4-Copy" transform="translate(586.087964, 15.755663) rotate(-26.000000) translate(-586.087964, -15.755663) translate(573.036290, 1.253803)">
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
</g>
<g id="Group-3-Copy" transform="translate(499.346306, 15.541651) rotate(25.000000) translate(-499.346306, -15.541651) translate(486.346306, 0.541651)">
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
</g>
<g id="Group-3" transform="translate(1.896652, 21.000000)">
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
</g>
<g id="Group-4" transform="translate(117.000000, 22.000000)">
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
</g>
<g id="Group-2" transform="translate(26.000000, 40.000000)">
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
</g>
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
</g>
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
</g>
</g>
<g id="Group-2-Copy" transform="translate(544.820291, 73.354278) scale(-1, 1) translate(-544.820291, -73.354278) translate(498.000000, 40.000000)">
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
</g>
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
</g>
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
</g>
</g>
<text id="QUIC" font-family="SourceCodePro-Light, Source Code Pro" font-size="288" font-weight="300" letter-spacing="-20" fill="#000000">
<tspan x="-13.6" y="197">QUI</tspan>
<tspan x="444.8" y="197">C</tspan>
</text>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 11 KiB

View File

@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
s.readPos += protocol.ByteCount(len(data))
return data, s.readPos >= s.finalOffset
}
// HasMoreData says if there is any more data queued at *any* offset.
func (s *frameSorter) HasMoreData() bool {
return len(s.queue) > 0
}

109
vendor/github.com/lucas-clemente/quic-go/framer.go generated vendored Normal file
View File

@ -0,0 +1,109 @@
package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type framer interface {
QueueControlFrame(wire.Frame)
AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID)
AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
}
type framerI struct {
mutex sync.Mutex
streamGetter streamGetter
version protocol.VersionNumber
activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
}
var _ framer = &framerI{}
func newFramer(
streamGetter streamGetter,
v protocol.VersionNumber,
) framer {
return &framerI{
streamGetter: streamGetter,
activeStreams: make(map[protocol.StreamID]struct{}),
version: v,
}
}
func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Lock()
f.controlFrames = append(f.controlFrames, frame)
f.controlFrameMutex.Unlock()
}
func (f *framerI) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) {
var length protocol.ByteCount
f.controlFrameMutex.Lock()
for len(f.controlFrames) > 0 {
frame := f.controlFrames[len(f.controlFrames)-1]
frameLen := frame.Length(f.version)
if length+frameLen > maxLen {
break
}
frames = append(frames, frame)
length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
}
f.controlFrameMutex.Unlock()
return frames, length
}
func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue = append(f.streamQueue, id)
f.activeStreams[id] = struct{}{}
}
f.mutex.Unlock()
}
func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame {
var length protocol.ByteCount
f.mutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
numActiveStreams := len(f.streamQueue)
for i := 0; i < numActiveStreams; i++ {
if maxLen-length < protocol.MinStreamFrameSize {
break
}
id := f.streamQueue[0]
f.streamQueue = f.streamQueue[1:]
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id)
// The stream can be nil if it completed after it said it had data.
if str == nil || err != nil {
delete(f.activeStreams, id)
continue
}
frame, hasMoreData := str.popStreamFrame(maxLen - length)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue = append(f.streamQueue, id)
} else { // no more data to send. Stream is not active any more
delete(f.activeStreams, id)
}
if frame == nil { // can happen if the receiveStream was canceled after it said it had data
continue
}
frames = append(frames, frame)
length += frame.Length(f.version)
}
f.mutex.Unlock()
return frames
}

View File

@ -1,314 +0,0 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type roundTripperOpts struct {
DisableCompression bool
}
var dialAddr = quic.DialAddr
// client is a HTTP2 client doing QUIC requests
type client struct {
mutex sync.RWMutex
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
hostname string
handshakeErr error
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
session quic.Session
headerStream quic.Stream
headerErr *qerr.QuicError
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
logger utils.Logger
}
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDOmission: true,
KeepAlive: true,
}
// newClient creates a new client
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
logger: utils.DefaultLogger.WithPrefix("client"),
}
}
// dial dials the connection
func (c *client) dial() error {
var err error
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
// once the version has been negotiated, open the header stream
c.headerStream, err = c.session.OpenStream()
if err != nil {
return err
}
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
go c.handleHeaderStream()
return nil
}
func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var err error
for err == nil {
err = c.readResponse(h2framer, decoder)
}
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
c.logger.Debugf("Error handling header stream: %s", err)
}
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request
close(c.headerErrored)
}
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
frame, err := h2framer.ReadFrame()
if err != nil {
return err
}
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
return errors.New("not a headers frame")
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
return fmt.Errorf("cannot read header fields: %s", err.Error())
}
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
return err
}
responseChan <- rsp
return nil
}
// Roundtrip executes a request and returns a response
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme")
}
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
}
c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
}
hasBody := (req.Body != nil)
responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync()
if err != nil {
_ = c.closeWithError(err)
return nil, err
}
c.mutex.Lock()
c.responses[dataStream.StreamID()] = responseChan
c.mutex.Unlock()
var requestedGzip bool
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true
}
// TODO: add support for trailers
endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil {
_ = c.closeWithError(err)
return nil, err
}
resc := make(chan error, 1)
if hasBody {
go func() {
resc <- c.writeRequestBody(dataStream, req.Body)
}()
}
var res *http.Response
var receivedResponse bool
var bodySent bool
if !hasBody {
bodySent = true
}
ctx := req.Context()
for !(bodySent && receivedResponse) {
select {
case res = <-responseChan:
receivedResponse = true
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
case err := <-resc:
bodySent = true
if err != nil {
return nil, err
}
case <-ctx.Done():
// error code 6 signals that stream was canceled
dataStream.CancelRead(6)
dataStream.CancelWrite(6)
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
return nil, ctx.Err()
case <-c.headerErrored:
// an error occurred on the header stream
_ = c.closeWithError(c.headerErr)
return nil, c.headerErr
}
}
// TODO: correctly set this variable
var streamEnded bool
isHead := (req.Method == "HEAD")
res = setLength(res, isHead, streamEnded)
if streamEnded || isHead {
res.Body = noBody
} else {
res.Body = dataStream
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &gzipReader{body: res.Body}
res.Uncompressed = true
}
}
res.Request = req
return res, nil
}
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
defer func() {
cerr := body.Close()
if err == nil {
// TODO: what to do with dataStream here? Maybe reset it?
err = cerr
}
}()
_, err = io.Copy(dataStream, body)
if err != nil {
// TODO: what to do with dataStream here? Maybe reset it?
return err
}
return dataStream.Close()
}
func (c *client) closeWithError(e error) error {
if c.session == nil {
return nil
}
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
}
// Close closes the client
func (c *client) Close() error {
if c.session == nil {
return nil
}
return c.session.Close()
}
// copied from net/transport.go
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}

View File

@ -1,35 +0,0 @@
package h2quic
// copied from net/transport.go
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
)
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}

View File

@ -1,77 +0,0 @@
package h2quic
import (
"crypto/tls"
"errors"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http2/hpack"
)
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
var path, authority, method, contentLengthStr string
httpHeaders := http.Header{}
for _, h := range headers {
switch h.Name {
case ":path":
path = h.Value
case ":method":
method = h.Value
case ":authority":
authority = h.Value
case "content-length":
contentLengthStr = h.Value
default:
if !h.IsPseudo() {
httpHeaders.Add(h.Name, h.Value)
}
}
}
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(httpHeaders["Cookie"]) > 0 {
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
}
if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty")
}
u, err := url.Parse(path)
if err != nil {
return nil, err
}
var contentLength int64
if len(contentLengthStr) > 0 {
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, err
}
}
return &http.Request{
Method: method,
URL: u,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders,
Body: nil,
ContentLength: contentLength,
Host: authority,
RequestURI: path,
TLS: &tls.ConnectionState{},
}, nil
}
func hostnameFromRequest(req *http.Request) string {
if req.URL != nil {
return req.URL.Host
}
return ""
}

View File

@ -1,29 +0,0 @@
package h2quic
import (
"io"
quic "github.com/lucas-clemente/quic-go"
)
type requestBody struct {
requestRead bool
dataStream quic.Stream
}
// make sure the requestBody can be used as a http.Request.Body
var _ io.ReadCloser = &requestBody{}
func newRequestBody(stream quic.Stream) *requestBody {
return &requestBody{dataStream: stream}
}
func (b *requestBody) Read(p []byte) (int, error) {
b.requestRead = true
return b.dataStream.Read(p)
}
func (b *requestBody) Close() error {
// stream's Close() closes the write side, not the read side
return nil
}

View File

@ -1,203 +0,0 @@
package h2quic
import (
"bytes"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type requestWriter struct {
mutex sync.Mutex
headerStream quic.Stream
henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this
logger utils.Logger
}
const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
rw := &requestWriter{
headerStream: headerStream,
logger: logger,
}
rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw
}
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
// TODO: add support for trailers
// TODO: add support for gzip compression
// TODO: write continuation frames, if the header frame is too long
w.mutex.Lock()
defer w.mutex.Unlock()
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
h2framer := http2.NewFramer(w.headerStream, nil)
return h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(dataStreamID),
EndHeaders: true,
EndStream: endStream,
BlockFragment: w.hbuf.Bytes(),
Priority: http2.PriorityParam{Weight: 0xff},
})
}
// the rest of this files is copied from http2.Transport
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
w.hbuf.Reset()
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return nil, err
}
var path string
if req.Method != "CONNECT" {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
}
return nil, fmt.Errorf("invalid request :path %q", orig)
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
}
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
w.writeHeader(":authority", host)
w.writeHeader(":method", req.Method)
if req.Method != "CONNECT" {
w.writeHeader(":path", path)
w.writeHeader(":scheme", req.URL.Scheme)
}
if trailers != "" {
w.writeHeader("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
switch lowKey {
case "host", "content-length":
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
case "user-agent":
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
w.writeHeader(lowKey, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
w.writeHeader("accept-encoding", "gzip")
}
if !didUA {
w.writeHeader("user-agent", defaultUserAgent)
}
return w.hbuf.Bytes(), nil
}
func (w *requestWriter) writeHeader(name, value string) {
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
func actualContentLength(req *http.Request) int64 {
if req.Body == nil {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}

View File

@ -1,95 +0,0 @@
package h2quic
import (
"bytes"
"errors"
"io/ioutil"
"net/http"
"net/textproto"
"strconv"
"strings"
"golang.org/x/net/http2"
)
// copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
if f.Truncated {
return nil, errResponseHeaderListSize
}
status := f.PseudoValue("status")
if status == "" {
return nil, errors.New("missing status pseudo header")
}
statusCode, err := strconv.Atoi(status)
if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header")
}
// TODO: handle statusCode == 100
header := make(http.Header)
res := &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: header,
StatusCode: statusCode,
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
t = make(http.Header)
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil
})
} else {
header[key] = append(header[key], hf.Value)
}
}
return res, nil
}
// continuation of the handleResponse function
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if !streamEnded || isHead {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64
}
}
}
return res
}
// copied from net/http/server.go
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}

View File

@ -1,114 +0,0 @@
package h2quic
import (
"bytes"
"net/http"
"strconv"
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
type responseWriter struct {
dataStreamID protocol.StreamID
dataStream quic.Stream
headerStream quic.Stream
headerStreamMutex *sync.Mutex
header http.Header
status int // status code passed to WriteHeader
headerWritten bool
logger utils.Logger
}
func newResponseWriter(
headerStream quic.Stream,
headerStreamMutex *sync.Mutex,
dataStream quic.Stream,
dataStreamID protocol.StreamID,
logger utils.Logger,
) *responseWriter {
return &responseWriter{
header: http.Header{},
headerStream: headerStream,
headerStreamMutex: headerStreamMutex,
dataStream: dataStream,
dataStreamID: dataStreamID,
logger: logger,
}
}
func (w *responseWriter) Header() http.Header {
return w.header
}
func (w *responseWriter) WriteHeader(status int) {
if w.headerWritten {
return
}
w.headerWritten = true
w.status = status
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
for k, v := range w.header {
for index := range v {
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
}
}
w.logger.Infof("Responding with %d", status)
w.headerStreamMutex.Lock()
defer w.headerStreamMutex.Unlock()
h2framer := http2.NewFramer(w.headerStream, nil)
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(w.dataStreamID),
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
if err != nil {
w.logger.Errorf("could not write h2 header: %s", err.Error())
}
}
func (w *responseWriter) Write(p []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(200)
}
if !bodyAllowedForStatus(w.status) {
return 0, http.ErrBodyNotAllowed
}
return w.dataStream.Write(p)
}
func (w *responseWriter) Flush() {}
// This is a NOP. Use http.Request.Context
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
// test that we implement http.Flusher
var _ http.Flusher = &responseWriter{}
// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}

View File

@ -1,9 +0,0 @@
package h2quic
import "net/http"
// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11.
// By defining it in a separate file, we can exclude this file from staticcheck.
// test that we implement http.CloseNotifier
var _ http.CloseNotifier = &responseWriter{}

View File

@ -1,179 +0,0 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/http/httpguts"
)
type roundTripCloser interface {
http.RoundTripper
io.Closer
}
// RoundTripper implements the http.RoundTripper interface
type RoundTripper struct {
mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
clients map[string]roundTripCloser
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may
// create a new QUIC connection. If set true and
// no cached connection is available, RoundTrip
// will return ErrNoCachedConn.
OnlyCachedConn bool
}
var _ roundTripCloser = &RoundTripper{}
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL")
}
if req.URL.Host == "" {
closeRequestBody(req)
return nil, errors.New("quic: no Host in request URL")
}
if req.Header == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.Header")
}
if req.URL.Scheme == "https" {
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
}
}
}
} else {
closeRequestBody(req)
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
}
if req.Method != "" && !validMethod(req.Method) {
closeRequestBody(req)
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
}
hostname := authorityAddr("https", hostnameFromRequest(req))
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
return cl.RoundTrip(req)
}
// RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]roundTripCloser)
}
client, ok := r.clients[hostname]
if !ok {
if onlyCached {
return nil, ErrNoCachedConn
}
client = newClient(
hostname,
r.TLSClientConfig,
&roundTripperOpts{DisableCompression: r.DisableCompression},
r.QuicConfig,
r.Dial,
)
r.clients[hostname] = client
}
return client, nil
}
// Close closes the QUIC connections that this RoundTripper has used
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
return nil
}
func closeRequestBody(req *http.Request) {
if req.Body != nil {
req.Body.Close()
}
}
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// copied from net/http/http.go
func isNotToken(r rune) bool {
return !httpguts.IsTokenRune(r)
}

View File

@ -1,402 +0,0 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
type streamCreator interface {
quic.Session
GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
}
type remoteCloser interface {
CloseRemote(protocol.ByteCount)
}
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = quic.Listen
quicListenAddr = quic.ListenAddr
)
// Server is a HTTP2 server listening for QUIC connections.
type Server struct {
*http.Server
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
// If nil, it uses reasonable default values.
QuicConfig *quic.Config
// Private flag for demo, do not use
CloseAfterFirstRequest bool
port uint32 // used atomically
listenerMutex sync.Mutex
listener quic.Listener
closed bool
supportedVersionsAsString string
logger utils.Logger // will be set by Server.serveImpl()
}
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
func (s *Server) ListenAndServe() error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
return s.serveImpl(s.TLSConfig, nil)
}
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
}
return s.serveImpl(config, nil)
}
// Serve an existing UDP connection.
func (s *Server) Serve(conn net.PacketConn) error {
return s.serveImpl(s.TLSConfig, conn)
}
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
s.logger = utils.DefaultLogger.WithPrefix("server")
s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
return errors.New("Server is already closed")
}
if s.listener != nil {
s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
}
var ln quic.Listener
var err error
if conn == nil {
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
} else {
ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
}
if err != nil {
s.listenerMutex.Unlock()
return err
}
s.listener = ln
s.listenerMutex.Unlock()
for {
sess, err := ln.Accept()
if err != nil {
return err
}
go s.handleHeaderStream(sess.(streamCreator))
}
}
func (s *Server) handleHeaderStream(session streamCreator) {
stream, err := session.AcceptStream()
if err != nil {
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
return
}
hpackDecoder := hpack.NewDecoder(4096, nil)
h2framer := http2.NewFramer(nil, stream)
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for {
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
// QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't
// need to log it again.
errorCode := qerr.InternalError
if qerr, ok := err.(*qerr.QuicError); !ok {
errorCode = qerr.ErrorCode
s.logger.Errorf("error handling h2 request: %s", err.Error())
}
session.CloseWithError(quic.ErrorCode(errorCode), err)
return
}
}
}
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
h2frame, err := h2framer.ReadFrame()
if err != nil {
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
}
var h2headersFrame *http2.HeadersFrame
switch f := h2frame.(type) {
case *http2.PriorityFrame:
// ignore PRIORITY frames
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
return nil
case *http2.HeadersFrame:
h2headersFrame = f
default:
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
}
if !h2headersFrame.HeadersEnded() {
return errors.New("http2 header continuation not implemented")
}
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
if err != nil {
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
return err
}
req, err := requestFromHeaders(headers)
if err != nil {
return err
}
if s.logger.Debug() {
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
} else {
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
}
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
if err != nil {
return err
}
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
if dataStream == nil {
return nil
}
// handleRequest should be as non-blocking as possible to minimize
// head-of-line blocking. Potentially blocking code is run in a separate
// goroutine, enabling handleRequest to return before the code is executed.
go func() {
streamEnded := h2headersFrame.StreamEnded()
if streamEnded {
dataStream.(remoteCloser).CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof
}
req = req.WithContext(dataStream.Context())
reqBody := newRequestBody(dataStream)
req.Body = reqBody
req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
}
panicked := false
func() {
defer func() {
if p := recover(); p != nil {
// Copied from net/http/server.go
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
panicked = true
}
}()
handler.ServeHTTP(responseWriter, req)
}()
if panicked {
responseWriter.WriteHeader(500)
} else {
responseWriter.WriteHeader(200)
}
if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
// in gQUIC, the error code doesn't matter, so just use 0 here
responseWriter.dataStream.CancelRead(0)
}
responseWriter.dataStream.Close()
}
if s.CloseAfterFirstRequest {
time.Sleep(100 * time.Millisecond)
session.Close()
}
}()
return nil
}
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) Close() error {
s.listenerMutex.Lock()
defer s.listenerMutex.Unlock()
s.closed = true
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
return err
}
return nil
}
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) CloseGracefully(timeout time.Duration) error {
// TODO: implement
return nil
}
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error {
port := atomic.LoadUint32(&s.port)
if port == 0 {
// Extract port from s.Server.Addr
_, portStr, err := net.SplitHostPort(s.Server.Addr)
if err != nil {
return err
}
portInt, err := net.LookupPort("tcp", portStr)
if err != nil {
return err
}
port = uint32(portInt)
atomic.StoreUint32(&s.port, port)
}
if s.supportedVersionsAsString == "" {
var versions []string
for _, v := range protocol.SupportedVersions {
versions = append(versions, v.ToAltSvc())
}
s.supportedVersionsAsString = strings.Join(versions, ",")
}
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
return nil
}
// ListenAndServeQUIC listens on the UDP network address addr and calls the
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
// used when handler is nil.
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
server := &Server{
Server: &http.Server{
Addr: addr,
Handler: handler,
},
}
return server.ListenAndServeTLS(certFile, keyFile)
}
// ListenAndServe listens on the given network address for both, TLS and QUIC
// connetions in parallel. It returns if one of the two returns an error.
// http.DefaultServeMux is used when handler is nil.
// The correct Alt-Svc headers for QUIC are set.
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
// Load certs
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
}
// Open the listeners
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
defer udpConn.Close()
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return err
}
tcpConn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return err
}
defer tcpConn.Close()
tlsConn := tls.NewListener(tcpConn, config)
defer tlsConn.Close()
// Start the servers
httpServer := &http.Server{
Addr: addr,
TLSConfig: config,
}
quicServer := &Server{
Server: httpServer,
}
if handler == nil {
handler = http.DefaultServeMux
}
httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
quicServer.SetQuicHeaders(w.Header())
handler.ServeHTTP(w, r)
})
hErr := make(chan error)
qErr := make(chan error)
go func() {
hErr <- httpServer.Serve(tlsConn)
}()
go func() {
qErr <- quicServer.Serve(udpConn)
}()
select {
case err := <-hErr:
quicServer.Close()
return err
case err := <-qErr:
// Cannot close the HTTP server or wait for requests to complete properly :/
return err
}
}

View File

@ -16,19 +16,11 @@ type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
const (
// VersionGQUIC39 is gQUIC version 39.
VersionGQUIC39 = protocol.Version39
// VersionGQUIC43 is gQUIC version 43.
VersionGQUIC43 = protocol.Version43
// VersionGQUIC44 is gQUIC version 44.
VersionGQUIC44 = protocol.Version44
// VersionMilestone0_10_0 uses TLS
VersionMilestone0_10_0 = protocol.VersionMilestone0_10_0
)
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
type Cookie struct {
RemoteAddr string
SentTime time.Time
}
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
@ -166,11 +158,7 @@ type Config struct {
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []VersionNumber
// Ask the server to omit the connection ID sent in the Public Header.
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDOmission bool
// The length of the connection ID in bytes. Only valid for IETF QUIC.
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
@ -200,13 +188,10 @@ type Config struct {
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// This value doesn't have any effect in Google QUIC.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingUniStreams int
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool

View File

@ -27,11 +27,12 @@ type SentPacketHandler interface {
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() *Packet
DequeueProbePacket() (*Packet, error)
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber() protocol.PacketNumber
GetAlarmTimeout() time.Time
OnAlarm() error

View File

@ -1,10 +1,11 @@
package quic
package ackhandler
import (
"crypto/rand"
"math"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// The packetNumberGenerator generates the packet number for the next packet
@ -15,13 +16,17 @@ type packetNumberGenerator struct {
next protocol.PacketNumber
nextToSkip protocol.PacketNumber
history []protocol.PacketNumber
}
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
return &packetNumberGenerator{
g := &packetNumberGenerator{
next: initial,
averagePeriod: averagePeriod,
}
g.generateNewSkip()
return g
}
func (p *packetNumberGenerator) Peek() protocol.PacketNumber {
@ -35,6 +40,10 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
p.next++
if p.next == p.nextToSkip {
if len(p.history)+1 > protocol.MaxTrackedSkippedPackets {
p.history = p.history[1:]
}
p.history = append(p.history, p.next)
p.next++
p.generateNewSkip()
}
@ -42,28 +51,28 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
return next
}
func (p *packetNumberGenerator) generateNewSkip() error {
num, err := p.getRandomNumber()
if err != nil {
return err
}
func (p *packetNumberGenerator) generateNewSkip() {
num := p.getRandomNumber()
skip := protocol.PacketNumber(num) * (p.averagePeriod - 1) / (math.MaxUint16 / 2)
// make sure that there are never two consecutive packet numbers that are skipped
p.nextToSkip = p.next + 2 + skip
return nil
}
// getRandomNumber() generates a cryptographically secure random number between 0 and MaxUint16 (= 65535)
// The expectation value is 65535/2
func (p *packetNumberGenerator) getRandomNumber() (uint16, error) {
func (p *packetNumberGenerator) getRandomNumber() uint16 {
b := make([]byte, 2)
_, err := rand.Read(b)
if err != nil {
return 0, err
}
rand.Read(b) // ignore the error here
num := uint16(b[0])<<8 + uint16(b[1])
return num, nil
return num
}
func (p *packetNumberGenerator) Validate(ack *wire.AckFrame) bool {
for _, pn := range p.history {
if ack.AcksPacket(pn) {
return false
}
}
return true
}

View File

@ -2,9 +2,9 @@ package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
// The receivedPacketHistory stores if a packet number has already been received.

View File

@ -16,8 +16,6 @@ func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f wire.Frame) bool {
switch f.(type) {
case *wire.StopWaitingFrame:
return false
case *wire.AckFrame:
return false
default:

View File

@ -8,9 +8,9 @@ import (
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
@ -30,12 +30,13 @@ const (
)
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
lastSentPacketNumber protocol.PacketNumber
packetNumberGenerator *packetNumberGenerator
lastSentRetransmittablePacketTime time.Time
lastSentHandshakePacketTime time.Time
nextPacketSendTime time.Time
skippedPackets []protocol.PacketNumber
largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
@ -45,8 +46,7 @@ type sentPacketHandler struct {
lowestPacketNotConfirmedAcked protocol.PacketNumber
largestSentBeforeRTO protocol.PacketNumber
packetHistory *sentPacketHistory
stopWaitingManager stopWaitingManager
packetHistory *sentPacketHistory
retransmissionQueue []*Packet
@ -90,12 +90,12 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
)
return &sentPacketHandler{
packetHistory: newSentPacketHistory(),
stopWaitingManager: stopWaitingManager{},
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
packetHistory: newSentPacketHistory(),
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
@ -110,13 +110,13 @@ func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
if packet.EncryptionLevel == protocol.Encryption1RTT {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
if p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
}
return true, nil
@ -148,10 +148,7 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.skippedPackets = append(h.skippedPackets, p)
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
h.skippedPackets = h.skippedPackets[1:]
}
h.logger.Debugf("Skipping packet number %#x", p)
}
h.lastSentPacketNumber = packet.PacketNumber
@ -166,7 +163,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
if packet.EncryptionLevel != protocol.Encryption1RTT {
h.lastSentHandshakePacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
@ -198,7 +195,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
h.largestReceivedPacketWithAck = withPacketNumber
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
if h.skippedPacketsAcked(ackFrame) {
if !h.packetNumberGenerator.Validate(ackFrame) {
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
}
@ -213,9 +210,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
priorInFlight := h.bytesInFlight
for _, p := range ackedPackets {
if encLevel < p.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
}
// TODO(#1534): check the encryption level
// if encLevel < p.EncryptionLevel {
// return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
// }
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
@ -234,10 +233,6 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
return err
}
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
return nil
}
@ -519,12 +514,13 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
return h.DequeuePacketForRetransmission(), nil
}
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
pn := h.packetNumberGenerator.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
return h.packetNumberGenerator.Pop()
}
func (h *sentPacketHandler) SendMode() SendMode {
@ -585,7 +581,7 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
if p.canBeRetransmitted && p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
}
return true, nil
@ -607,7 +603,6 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
return err
}
h.retransmissionQueue = append(h.retransmissionQueue, p)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
return nil
}
@ -633,26 +628,6 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
}
rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff
rto = rto << h.rtoCount
rto <<= h.rtoCount
return utils.MinDuration(rto, maxRTOTimeout)
}
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
for _, p := range h.skippedPackets {
if ackFrame.AcksPacket(p) {
return true
}
}
return false
}
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
lowestUnacked := h.lowestUnacked()
deleteIndex := 0
for i, p := range h.skippedPackets {
if p < lowestUnacked {
deleteIndex = i + 1
}
}
h.skippedPackets = h.skippedPackets[deleteIndex:]
}

View File

@ -35,7 +35,7 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
}
if p.canBeRetransmitted {
h.numOutstandingPackets++
if p.EncryptionLevel < protocol.EncryptionForwardSecure {
if p.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets++
}
}
@ -106,7 +106,7 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
@ -147,7 +147,7 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")

View File

@ -1,43 +0,0 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
type stopWaitingManager struct {
largestLeastUnackedSent protocol.PacketNumber
nextLeastUnacked protocol.PacketNumber
lastStopWaitingFrame *wire.StopWaitingFrame
}
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
if force {
return s.lastStopWaitingFrame
}
return nil
}
s.largestLeastUnackedSent = s.nextLeastUnacked
swf := &wire.StopWaitingFrame{
LeastUnacked: s.nextLeastUnacked,
}
s.lastStopWaitingFrame = swf
return swf
}
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
largestAcked := ack.LargestAcked()
if largestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = largestAcked + 1
}
}
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
if p >= s.nextLeastUnacked {
s.nextLeastUnacked = p + 1
}
}

View File

@ -193,7 +193,7 @@ func (c *cubicSender) OnPacketLost(
if c.congestionWindow >= 2*c.initialCongestionWindow {
c.minSlowStartExitWindow = c.congestionWindow / 2
}
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
c.congestionWindow -= protocol.DefaultTCPMSS
} else if c.reno {
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
} else {

View File

@ -1,72 +0,0 @@
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/lucas-clemente/aes12"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadAESGCM12 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
var _ AEAD = &aeadAESGCM12{}
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
//
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
// tag size, and couples the cipher and aes packages closely.
// See https://github.com/lucas-clemente/aes12.
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
}
encrypterCipher, err := aes12.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := aes12.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes12.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := aes12.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM12{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}
func (aead *aeadAESGCM12) Overhead() int {
return aead.encrypter.Overhead()
}

View File

@ -1,48 +0,0 @@
package crypto
import (
"fmt"
"hash/fnv"
"github.com/hashicorp/golang-lru"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
compressedCertsCache *lru.Cache
)
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
// Hash all inputs
hasher := fnv.New64a()
for _, v := range chain {
hasher.Write(v)
}
hasher.Write(pCommonSetHashes)
hasher.Write(pCachedHashes)
hash := hasher.Sum64()
var result []byte
resultI, isCached := compressedCertsCache.Get(hash)
if isCached {
result = resultI.([]byte)
} else {
var err error
result, err = compressChain(chain, pCommonSetHashes, pCachedHashes)
if err != nil {
return nil, err
}
compressedCertsCache.Add(hash, result)
}
return result, nil
}
func init() {
var err error
compressedCertsCache, err = lru.New(protocol.NumCachedCertificates)
if err != nil {
panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error()))
}
}

View File

@ -1,113 +0,0 @@
package crypto
import (
"crypto/tls"
"errors"
"strings"
)
// A CertChain holds a certificate and a private key
type CertChain interface {
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
GetLeafCert(sni string) ([]byte, error)
}
// proofSource stores a key and a certificate for the server proof
type certChain struct {
config *tls.Config
}
var _ CertChain = &certChain{}
var errNoMatchingCertificate = errors.New("no matching certificate found")
// NewCertChain loads the key and cert from files
func NewCertChain(tlsConfig *tls.Config) CertChain {
return &certChain{config: tlsConfig}
}
// SignServerProof signs CHLO and server config for use in the server proof
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return signServerProof(cert, chlo, serverConfigData)
}
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes)
}
// GetLeafCert gets the leaf certificate
func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return cert.Certificate[0], nil
}
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
conf := c.config
conf, err := maybeGetConfigForClient(conf, sni)
if err != nil {
return nil, err
}
// The rest of this function is mostly copied from crypto/tls.getCertificate
if conf.GetCertificate != nil {
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil {
return cert, err
}
}
if len(conf.Certificates) == 0 {
return nil, errNoMatchingCertificate
}
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
// There's only one choice, so no point doing any work.
return &conf.Certificates[0], nil
}
name := strings.ToLower(sni)
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
if cert, ok := conf.NameToCertificate[name]; ok {
return cert, nil
}
// try replacing labels in the name with wildcards until we get a
// match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := conf.NameToCertificate[candidate]; ok {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return &conf.Certificates[0], nil
}
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}

View File

@ -1,272 +0,0 @@
package crypto
import (
"bytes"
"compress/flate"
"compress/zlib"
"encoding/binary"
"errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type entryType uint8
const (
entryCompressed entryType = 1
entryCached entryType = 2
entryCommon entryType = 3
)
type entry struct {
t entryType
h uint64 // set hash
i uint32 // index
}
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
res := &bytes.Buffer{}
cachedHashes, err := splitHashes(pCachedHashes)
if err != nil {
return nil, err
}
setHashes, err := splitHashes(pCommonSetHashes)
if err != nil {
return nil, err
}
chainHashes := make([]uint64, len(chain))
for i := range chain {
chainHashes[i] = HashCert(chain[i])
}
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
totalUncompressedLen := 0
for i, e := range entries {
res.WriteByte(uint8(e.t))
switch e.t {
case entryCached:
utils.LittleEndian.WriteUint64(res, e.h)
case entryCommon:
utils.LittleEndian.WriteUint64(res, e.h)
utils.LittleEndian.WriteUint32(res, e.i)
case entryCompressed:
totalUncompressedLen += 4 + len(chain[i])
}
}
res.WriteByte(0) // end of list
if totalUncompressedLen > 0 {
gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain))
if err != nil {
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
}
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
for i, e := range entries {
if e.t != entryCompressed {
continue
}
lenCert := len(chain[i])
gz.Write([]byte{
byte(lenCert & 0xff),
byte((lenCert >> 8) & 0xff),
byte((lenCert >> 16) & 0xff),
byte((lenCert >> 24) & 0xff),
})
gz.Write(chain[i])
}
gz.Close()
}
return res.Bytes(), nil
}
func decompressChain(data []byte) ([][]byte, error) {
var chain [][]byte
var entries []entry
r := bytes.NewReader(data)
var numCerts int
var hasCompressedCerts bool
for {
entryTypeByte, err := r.ReadByte()
if entryTypeByte == 0 {
break
}
et := entryType(entryTypeByte)
if err != nil {
return nil, err
}
numCerts++
switch et {
case entryCached:
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
return nil, errors.New("unexpected cached certificate")
case entryCommon:
e := entry{t: entryCommon}
e.h, err = utils.LittleEndian.ReadUint64(r)
if err != nil {
return nil, err
}
e.i, err = utils.LittleEndian.ReadUint32(r)
if err != nil {
return nil, err
}
certSet, ok := certSets[e.h]
if !ok {
return nil, errors.New("unknown certSet")
}
if e.i >= uint32(len(certSet)) {
return nil, errors.New("certificate not found in certSet")
}
entries = append(entries, e)
chain = append(chain, certSet[e.i])
case entryCompressed:
hasCompressedCerts = true
entries = append(entries, entry{t: entryCompressed})
chain = append(chain, nil)
default:
return nil, errors.New("unknown entryType")
}
}
if numCerts == 0 {
return make([][]byte, 0), nil
}
if hasCompressedCerts {
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
if err != nil {
fmt.Println(4)
return nil, err
}
zlibDict := buildZlibDictForEntries(entries, chain)
gz, err := zlib.NewReaderDict(r, zlibDict)
if err != nil {
return nil, err
}
defer gz.Close()
var totalLength uint32
var certIndex int
for totalLength < uncompressedLength {
lenBytes := make([]byte, 4)
_, err := gz.Read(lenBytes)
if err != nil {
return nil, err
}
certLen := binary.LittleEndian.Uint32(lenBytes)
cert := make([]byte, certLen)
n, err := gz.Read(cert)
if uint32(n) != certLen && err != nil {
return nil, err
}
for {
if certIndex >= len(entries) {
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
}
if entries[certIndex].t == entryCompressed {
chain[certIndex] = cert
certIndex++
break
}
certIndex++
}
totalLength += 4 + certLen
}
}
return chain, nil
}
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
res := make([]entry, len(chain))
chainLoop:
for i := range chain {
// Check if hash is in cachedHashes
for j := range cachedHashes {
if chainHashes[i] == cachedHashes[j] {
res[i] = entry{t: entryCached, h: chainHashes[i]}
continue chainLoop
}
}
// Go through common sets and check if it's in there
for _, setHash := range setHashes {
set, ok := certSets[setHash]
if !ok {
// We don't have this set
continue
}
// We have this set, check if chain[i] is in the set
pos := set.findCertInSet(chain[i])
if pos >= 0 {
// Found
res[i] = entry{t: entryCommon, h: setHash, i: uint32(pos)}
continue chainLoop
}
}
res[i] = entry{t: entryCompressed}
}
return res
}
func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte {
var dict bytes.Buffer
// First the cached and common in reverse order
for i := len(entries) - 1; i >= 0; i-- {
if entries[i].t == entryCompressed {
continue
}
dict.Write(chain[i])
}
dict.Write(certDictZlib)
return dict.Bytes()
}
func splitHashes(hashes []byte) ([]uint64, error) {
if len(hashes)%8 != 0 {
return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes")
}
n := len(hashes) / 8
res := make([]uint64, n)
for i := 0; i < n; i++ {
res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8])
}
return res, nil
}
func getCommonCertificateHashes() []byte {
ccs := make([]byte, 8*len(certSets))
i := 0
for certSetHash := range certSets {
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
i++
}
return ccs
}
// HashCert calculates the FNV1a hash of a certificate
func HashCert(cert []byte) uint64 {
h := fnv.New64a()
h.Write(cert)
return h.Sum64()
}

View File

@ -1,128 +0,0 @@
package crypto
var certDictZlib = []byte{
0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04,
0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03,
0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x30,
0x5f, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x04, 0x01,
0x06, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, 0xfd, 0x6d, 0x01, 0x07,
0x17, 0x01, 0x30, 0x33, 0x20, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65,
0x64, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x20, 0x53, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x34,
0x20, 0x53, 0x53, 0x4c, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31,
0x32, 0x20, 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x20, 0x43, 0x41, 0x30, 0x2d, 0x61, 0x69, 0x61, 0x2e,
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x45, 0x2d, 0x63, 0x72, 0x6c, 0x2e, 0x76, 0x65, 0x72, 0x69, 0x73,
0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x45, 0x2e, 0x63, 0x65,
0x72, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01,
0x01, 0x05, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x4a, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73,
0x2f, 0x63, 0x70, 0x73, 0x20, 0x28, 0x63, 0x29, 0x30, 0x30, 0x09, 0x06,
0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x7b, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x0e, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86,
0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01,
0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xd2,
0x6f, 0x64, 0x6f, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x43, 0x2e,
0x63, 0x72, 0x6c, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
0x04, 0x14, 0xb4, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69,
0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x30, 0x0b, 0x06, 0x03,
0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x30, 0x0d, 0x06, 0x09,
0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30,
0x81, 0xca, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
0x02, 0x55, 0x53, 0x31, 0x10, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x08,
0x13, 0x07, 0x41, 0x72, 0x69, 0x7a, 0x6f, 0x6e, 0x61, 0x31, 0x13, 0x30,
0x11, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0a, 0x53, 0x63, 0x6f, 0x74,
0x74, 0x73, 0x64, 0x61, 0x6c, 0x65, 0x31, 0x1a, 0x30, 0x18, 0x06, 0x03,
0x55, 0x04, 0x0a, 0x13, 0x11, 0x47, 0x6f, 0x44, 0x61, 0x64, 0x64, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, 0x33,
0x30, 0x31, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x2a, 0x68, 0x74, 0x74,
0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63,
0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74,
0x6f, 0x72, 0x79, 0x31, 0x30, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x04, 0x03,
0x13, 0x27, 0x47, 0x6f, 0x20, 0x44, 0x61, 0x64, 0x64, 0x79, 0x20, 0x53,
0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66,
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68,
0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55,
0x04, 0x05, 0x13, 0x08, 0x30, 0x37, 0x39, 0x36, 0x39, 0x32, 0x38, 0x37,
0x30, 0x1e, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d,
0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x0c,
0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x02, 0x30, 0x00,
0x30, 0x1d, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff,
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05,
0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07,
0x03, 0x02, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff,
0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x33, 0x06, 0x03, 0x55, 0x1d,
0x1f, 0x04, 0x2c, 0x30, 0x2a, 0x30, 0x28, 0xa0, 0x26, 0xa0, 0x24, 0x86,
0x22, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e,
0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x67, 0x64, 0x73, 0x31, 0x2d, 0x32, 0x30, 0x2a, 0x30, 0x28, 0x06, 0x08,
0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x1c, 0x68, 0x74,
0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, 0x65,
0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
0x70, 0x73, 0x30, 0x34, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x5a, 0x17,
0x0d, 0x31, 0x33, 0x30, 0x35, 0x30, 0x39, 0x06, 0x08, 0x2b, 0x06, 0x01,
0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x3a,
0x2f, 0x2f, 0x73, 0x30, 0x39, 0x30, 0x37, 0x06, 0x08, 0x2b, 0x06, 0x01,
0x05, 0x05, 0x07, 0x02, 0x30, 0x44, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04,
0x3d, 0x30, 0x3b, 0x30, 0x39, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86,
0xf8, 0x45, 0x01, 0x07, 0x17, 0x06, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
0x55, 0x04, 0x06, 0x13, 0x02, 0x47, 0x42, 0x31, 0x1b, 0x53, 0x31, 0x17,
0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0e, 0x56, 0x65, 0x72,
0x69, 0x53, 0x69, 0x67, 0x6e, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31,
0x1f, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x16, 0x56, 0x65,
0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x54, 0x72, 0x75, 0x73, 0x74,
0x20, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x31, 0x3b, 0x30, 0x39,
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x32, 0x54, 0x65, 0x72, 0x6d, 0x73,
0x20, 0x6f, 0x66, 0x20, 0x75, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20, 0x68,
0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76,
0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x72, 0x70, 0x61, 0x20, 0x28, 0x63, 0x29, 0x30, 0x31, 0x10, 0x30, 0x0e,
0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x07, 0x53, 0x31, 0x13, 0x30, 0x11,
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0a, 0x47, 0x31, 0x13, 0x30, 0x11,
0x06, 0x0b, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x3c, 0x02, 0x01,
0x03, 0x13, 0x02, 0x55, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04,
0x03, 0x14, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13,
0x31, 0x1d, 0x30, 0x1b, 0x06, 0x03, 0x55, 0x04, 0x0f, 0x13, 0x14, 0x50,
0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e,
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x31, 0x12, 0x31, 0x21, 0x30,
0x1f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x18, 0x44, 0x6f, 0x6d, 0x61,
0x69, 0x6e, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x20, 0x56,
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x64, 0x31, 0x14, 0x31, 0x31,
0x30, 0x2f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x28, 0x53, 0x65, 0x65,
0x20, 0x77, 0x77, 0x77, 0x2e, 0x72, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63,
0x75, 0x72, 0x65, 0x2e, 0x67, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53,
0x69, 0x67, 0x6e, 0x31, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x41,
0x2e, 0x63, 0x72, 0x6c, 0x56, 0x65, 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e,
0x20, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x33, 0x20, 0x45, 0x63, 0x72,
0x6c, 0x2e, 0x67, 0x65, 0x6f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x63, 0x72, 0x6c, 0x73, 0x2f, 0x73, 0x64, 0x31, 0x1a,
0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x3a,
0x2f, 0x2f, 0x45, 0x56, 0x49, 0x6e, 0x74, 0x6c, 0x2d, 0x63, 0x63, 0x72,
0x74, 0x2e, 0x67, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x69, 0x63, 0x65, 0x72,
0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x31, 0x6f, 0x63, 0x73, 0x70, 0x2e,
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
0x30, 0x39, 0x72, 0x61, 0x70, 0x69, 0x64, 0x73, 0x73, 0x6c, 0x2e, 0x63,
0x6f, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72,
0x79, 0x2f, 0x30, 0x81, 0x80, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05,
0x07, 0x01, 0x01, 0x04, 0x74, 0x30, 0x72, 0x30, 0x24, 0x06, 0x08, 0x2b,
0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x18, 0x68, 0x74, 0x74,
0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x2e, 0x67, 0x6f, 0x64,
0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x4a, 0x06,
0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x3e, 0x68,
0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66,
0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64,
0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73,
0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x67, 0x64, 0x5f, 0x69, 0x6e, 0x74,
0x65, 0x72, 0x6d, 0x65, 0x64, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x72,
0x74, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16,
0x80, 0x14, 0xfd, 0xac, 0x61, 0x32, 0x93, 0x6c, 0x45, 0xd6, 0xe2, 0xee,
0x85, 0x5f, 0x9a, 0xba, 0xe7, 0x76, 0x99, 0x68, 0xcc, 0xe7, 0x30, 0x27,
0x86, 0x29, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x86, 0x30,
0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73,
}

View File

@ -1,135 +0,0 @@
package crypto
import (
"crypto/tls"
"crypto/x509"
"errors"
"hash/fnv"
"time"
"github.com/lucas-clemente/quic-go/qerr"
)
// CertManager manages the certificates sent by the server
type CertManager interface {
SetData([]byte) error
GetCommonCertificateHashes() []byte
GetLeafCert() []byte
GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error
GetChain() []*x509.Certificate
}
type certManager struct {
chain []*x509.Certificate
config *tls.Config
}
var _ CertManager = &certManager{}
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
// NewCertManager creates a new CertManager
func NewCertManager(tlsConfig *tls.Config) CertManager {
return &certManager{config: tlsConfig}
}
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
func (c *certManager) SetData(data []byte) error {
byteChain, err := decompressChain(data)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
chain := make([]*x509.Certificate, len(byteChain))
for i, data := range byteChain {
cert, err := x509.ParseCertificate(data)
if err != nil {
return err
}
chain[i] = cert
}
c.chain = chain
return nil
}
func (c *certManager) GetChain() []*x509.Certificate {
return c.chain
}
func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes()
}
// GetLeafCert returns the leaf certificate of the certificate chain
// it returns nil if the certificate chain has not yet been set
func (c *certManager) GetLeafCert() []byte {
if len(c.chain) == 0 {
return nil
}
return c.chain[0].Raw
}
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
func (c *certManager) GetLeafCertHash() (uint64, error) {
leafCert := c.GetLeafCert()
if leafCert == nil {
return 0, errNoCertificateChain
}
h := fnv.New64a()
_, err := h.Write(leafCert)
if err != nil {
return 0, err
}
return h.Sum64(), nil
}
// VerifyServerProof verifies the signature of the server config
// it should only be called after the certificate chain has been set, otherwise it returns false
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
if len(c.chain) == 0 {
return false
}
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
}
// Verify verifies the certificate chain
func (c *certManager) Verify(hostname string) error {
if len(c.chain) == 0 {
return errNoCertificateChain
}
if c.config != nil && c.config.InsecureSkipVerify {
return nil
}
leafCert := c.chain[0]
var opts x509.VerifyOptions
if c.config != nil {
opts.Roots = c.config.RootCAs
if c.config.Time == nil {
opts.CurrentTime = time.Now()
} else {
opts.CurrentTime = c.config.Time()
}
}
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
opts.DNSName = hostname
// the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 {
intermediates := x509.NewCertPool()
for i := 1; i < len(c.chain); i++ {
intermediates.AddCert(c.chain[i])
}
opts.Intermediates = intermediates
}
_, err := leafCert.Verify(opts)
return err
}

View File

@ -1,24 +0,0 @@
package crypto
import (
"bytes"
"github.com/lucas-clemente/quic-go-certificates"
)
type certSet [][]byte
var certSets = map[uint64]certSet{
certsets.CertSet2Hash: certsets.CertSet2,
certsets.CertSet3Hash: certsets.CertSet3,
}
// findCertInSet searches for the cert in the set. Negative return value means not found.
func (s *certSet) findCertInSet(cert []byte) int {
for i, c := range *s {
if bytes.Equal(c, cert) {
return i
}
}
return -1
}

View File

@ -1,61 +0,0 @@
// +build ignore
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/aead/chacha20"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadChacha20Poly1305 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
}
// copy because ChaCha20Poly1305 expects array pointers
var MyKey, OtherKey [32]byte
copy(MyKey[:], myKey)
copy(OtherKey[:], otherKey)
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
if err != nil {
return nil, err
}
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
if err != nil {
return nil, err
}
return &aeadChacha20Poly1305{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}

View File

@ -1,41 +0,0 @@
package crypto
import (
"crypto/rand"
"errors"
"golang.org/x/crypto/curve25519"
)
// KeyExchange manages the exchange of keys
type curve25519KEX struct {
secret [32]byte
public [32]byte
}
var _ KeyExchange = &curve25519KEX{}
// NewCurve25519KEX creates a new KeyExchange using Curve25519, see https://cr.yp.to/ecdh.html
func NewCurve25519KEX() (KeyExchange, error) {
c := &curve25519KEX{}
if _, err := rand.Read(c.secret[:]); err != nil {
return nil, errors.New("Curve25519: could not create private key")
}
curve25519.ScalarBaseMult(&c.public, &c.secret)
return c, nil
}
func (c *curve25519KEX) PublicKey() []byte {
return c.public[:]
}
func (c *curve25519KEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
if len(otherPublic) != 32 {
return nil, errors.New("Curve25519: expected public key of 32 byte")
}
var res [32]byte
var otherPublicArray [32]byte
copy(otherPublicArray[:], otherPublic)
curve25519.ScalarMult(&res, &c.secret, &otherPublicArray)
return res[:], nil
}

View File

@ -0,0 +1,58 @@
package crypto
import (
"crypto"
"crypto/hmac"
"encoding/binary"
)
// copied from https://github.com/cloudflare/tls-tris/blob/master/hkdf.go
func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
if salt == nil {
salt = make([]byte, hash.Size())
}
if secret == nil {
secret = make([]byte, hash.Size())
}
extractor := hmac.New(hash.New, salt)
extractor.Write(secret)
return extractor.Sum(nil)
}
// copied from https://github.com/cloudflare/tls-tris/blob/master/hkdf.go
func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte {
var (
expander = hmac.New(hash.New, prk)
res = make([]byte, l)
counter = byte(1)
prev []byte
)
if l > 255*expander.Size() {
panic("hkdf: requested too much output")
}
p := res
for len(p) > 0 {
expander.Reset()
expander.Write(prev)
expander.Write(info)
expander.Write([]byte{counter})
prev = expander.Sum(prev[:0])
counter++
n := copy(p, prev)
p = p[n:]
}
return res
}
// hkdfExpandLabel HKDF expands a label
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, length int) []byte {
const prefix = "quic "
qlabel := make([]byte, 2 /* length */ +1 /* length of label */ +len(prefix)+len(label)+1 /* length of context (empty) */)
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
qlabel[2] = uint8(len(prefix) + len(label))
copy(qlabel[3:], []byte(prefix+label))
return hkdfExpand(hash, secret, qlabel, length)
}

View File

@ -1,60 +0,0 @@
package crypto
import (
"crypto"
"encoding/binary"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
clientExporterLabel = "EXPORTER-QUIC client 1rtt"
serverExporterLabel = "EXPORTER-QUIC server 1rtt"
)
// A TLSExporter gets the negotiated ciphersuite and computes exporter
type TLSExporter interface {
ConnectionState() mint.ConnectionState
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
}
func qhkdfExpand(secret []byte, label string, length int) []byte {
qlabel := make([]byte, 2+1+5+len(label))
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
qlabel[2] = uint8(5 + len(label))
copy(qlabel[3:], []byte("QUIC "+label))
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
}
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
var myLabel, otherLabel string
if pers == protocol.PerspectiveClient {
myLabel = clientExporterLabel
otherLabel = serverExporterLabel
} else {
myLabel = serverExporterLabel
otherLabel = clientExporterLabel
}
myKey, myIV, err := computeKeyAndIV(tls, myLabel)
if err != nil {
return nil, err
}
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel)
if err != nil {
return nil, err
}
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
}
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
cs := tls.ConnectionState().CipherSuite
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
if err != nil {
return nil, nil, err
}
key = qhkdfExpand(secret, "key", cs.KeyLen)
iv = qhkdfExpand(secret, "iv", cs.IvLen)
return key, iv, nil
}

View File

@ -1,100 +0,0 @@
package crypto
import (
"bytes"
"crypto/sha256"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"golang.org/x/crypto/hkdf"
)
// DeriveKeysChacha20 derives the client and server keys and creates a matching chacha20poly1305 AEAD instance
// func DeriveKeysChacha20(version protocol.VersionNumber, forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) {
// otherKey, myKey, otherIV, myIV, err := deriveKeys(version, forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 32)
// if err != nil {
// return nil, err
// }
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
// }
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
var swap bool
if pers == protocol.PerspectiveClient {
swap = true
}
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap)
if err != nil {
return nil, err
}
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
}
// deriveKeys derives the keys and the IVs
// swap should be set true if generating the values for the client, and false for the server
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) {
var info bytes.Buffer
if forwardSecure {
info.Write([]byte("QUIC forward secure key expansion\x00"))
} else {
info.Write([]byte("QUIC key expansion\x00"))
}
info.Write(connID)
info.Write(chlo)
info.Write(scfg)
info.Write(cert)
r := hkdf.New(sha256.New, sharedSecret, nonces, info.Bytes())
s := make([]byte, 2*keyLen+2*4)
if _, err := io.ReadFull(r, s); err != nil {
return nil, nil, nil, nil, err
}
key1 := s[:keyLen]
key2 := s[keyLen : 2*keyLen]
iv1 := s[2*keyLen : 2*keyLen+4]
iv2 := s[2*keyLen+4:]
var otherKey, myKey []byte
var otherIV, myIV []byte
if !forwardSecure {
if err := diversify(key2, iv2, divNonce); err != nil {
return nil, nil, nil, nil, err
}
}
if swap {
otherKey = key2
myKey = key1
otherIV = iv2
myIV = iv1
} else {
otherKey = key1
myKey = key2
otherIV = iv1
myIV = iv2
}
return otherKey, myKey, otherIV, myIV, nil
}
func diversify(key, iv, divNonce []byte) error {
secret := make([]byte, len(key)+len(iv))
copy(secret, key)
copy(secret[len(key):], iv)
r := hkdf.New(sha256.New, secret, divNonce, []byte("QUIC key diversification"))
if _, err := io.ReadFull(r, key); err != nil {
return err
}
if _, err := io.ReadFull(r, iv); err != nil {
return err
}
return nil
}

View File

@ -1,7 +0,0 @@
package crypto
// KeyExchange manages the exchange of keys
type KeyExchange interface {
PublicKey() []byte
CalculateSharedKey(otherPublic []byte) ([]byte, error)
}

View File

@ -1,11 +0,0 @@
package crypto
import "github.com/lucas-clemente/quic-go/internal/protocol"
// NewNullAEAD creates a NullAEAD
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
if v.UsesTLS() {
return newNullAEADAESGCM(connID, p)
}
return &nullAEADFNV128a{perspective: p}, nil
}

View File

@ -3,13 +3,13 @@ package crypto
import (
"crypto"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
// NewNullAEAD creates a NullAEAD
func NewNullAEAD(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
clientSecret, serverSecret := computeSecrets(connectionID)
var mySecret, otherSecret []byte
@ -28,14 +28,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
}
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
initialSecret := hkdfExtract(crypto.SHA256, connID, quicVersion1Salt)
clientSecret = HkdfExpandLabel(crypto.SHA256, initialSecret, "client in", crypto.SHA256.Size())
serverSecret = HkdfExpandLabel(crypto.SHA256, initialSecret, "server in", crypto.SHA256.Size())
return
}
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
key = qhkdfExpand(secret, "key", 16)
iv = qhkdfExpand(secret, "iv", 12)
key = HkdfExpandLabel(crypto.SHA256, secret, "key", 16)
iv = HkdfExpandLabel(crypto.SHA256, secret, "iv", 12)
return
}

View File

@ -1,79 +0,0 @@
package crypto
import (
"bytes"
"errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// nullAEAD handles not-yet encrypted packets
type nullAEADFNV128a struct {
perspective protocol.Perspective
}
var _ AEAD = &nullAEADFNV128a{}
// Open and verify the ciphertext
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
if len(src) < 12 {
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
}
hash := fnv.New128a()
hash.Write(associatedData)
hash.Write(src[12:])
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Client"))
} else {
hash.Write([]byte("Server"))
}
sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
if !bytes.Equal(sum[:12], src[:12]) {
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
}
return src[12:], nil
}
// Seal writes hash and ciphertext to the buffer
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
if cap(dst) < 12+len(src) {
dst = make([]byte, 12+len(src))
} else {
dst = dst[:12+len(src)]
}
hash := fnv.New128a()
hash.Write(associatedData)
hash.Write(src)
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Server"))
} else {
hash.Write([]byte("Client"))
}
sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
copy(dst[12:], src)
copy(dst, sum[:12])
return dst
}
func (n *nullAEADFNV128a) Overhead() int {
return 12
}
func reverse(a []byte) {
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
a[left], a[right] = a[right], a[left]
}
}

View File

@ -1,66 +0,0 @@
package crypto
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"errors"
"math/big"
)
type ecdsaSignature struct {
R, S *big.Int
}
// signServerProof signs CHLO and server config for use in the server proof
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
key, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
}
opts := crypto.SignerOpts(crypto.SHA256)
if _, ok = key.(*rsa.PrivateKey); ok {
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
}
return key.Sign(rand.Reader, hash.Sum(nil), opts)
}
// verifyServerProof verifies the server proof signature
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
// RSA
if cert.PublicKeyAlgorithm == x509.RSA {
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts)
return err == nil
}
// ECDSA
signature := &ecdsaSignature{}
rest, err := asn1.Unmarshal(proof, signature)
if err != nil || len(rest) != 0 {
return false
}
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S)
}

View File

@ -5,8 +5,8 @@ import (
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type connectionFlowController struct {

View File

@ -19,7 +19,7 @@ type StreamFlowController interface {
flowController
// for receiving
// UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
}

View File

@ -5,8 +5,8 @@ import (
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type streamFlowController struct {
@ -16,8 +16,7 @@ type streamFlowController struct {
queueWindowUpdate func()
connection connectionFlowControllerI
contributesToConnection bool // does the stream contribute to connection level flow control
connection connectionFlowControllerI
receivedFinalOffset bool
}
@ -27,7 +26,6 @@ var _ StreamFlowController = &streamFlowController{}
// NewStreamFlowController gets a new flow controller for a stream
func NewStreamFlowController(
streamID protocol.StreamID,
contributesToConnection bool,
cfc ConnectionFlowController,
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
@ -37,10 +35,9 @@ func NewStreamFlowController(
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
contributesToConnection: contributesToConnection,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
@ -87,32 +84,21 @@ func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCou
if c.checkFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
}
if c.contributesToConnection {
return c.connection.IncrementHighestReceived(increment)
}
return nil
return c.connection.IncrementHighestReceived(increment)
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
c.baseFlowController.AddBytesRead(n)
if c.contributesToConnection {
c.connection.AddBytesRead(n)
}
c.connection.AddBytesRead(n)
}
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
c.baseFlowController.AddBytesSent(n)
if c.contributesToConnection {
c.connection.AddBytesSent(n)
}
c.connection.AddBytesSent(n)
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
window := c.baseFlowController.sendWindowSize()
if c.contributesToConnection {
window = utils.MinByteCount(window, c.connection.SendWindowSize())
}
return window
return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
}
func (c *streamFlowController) MaybeQueueWindowUpdate() {
@ -122,9 +108,7 @@ func (c *streamFlowController) MaybeQueueWindowUpdate() {
if hasWindowUpdate {
c.queueWindowUpdate()
}
if c.contributesToConnection {
c.connection.MaybeQueueWindowUpdate()
}
c.connection.MaybeQueueWindowUpdate()
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
@ -140,9 +124,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
c.mutex.Unlock()
return offset

View File

@ -0,0 +1,58 @@
package handshake
import (
"crypto/cipher"
"encoding/binary"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sealer struct {
iv []byte
aead cipher.AEAD
// use a single slice to avoid allocations
nonceBuf []byte
}
var _ Sealer = &sealer{}
func newSealer(aead cipher.AEAD, iv []byte) Sealer {
return &sealer{
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
return s.aead.Seal(dst, s.nonceBuf, src, ad)
}
func (s *sealer) Overhead() int {
return s.aead.Overhead()
}
type opener struct {
iv []byte
aead cipher.AEAD
// use a single slice to avoid allocations
nonceBuf []byte
}
var _ Opener = &opener{}
func newOpener(aead cipher.AEAD, iv []byte) Opener {
return &opener{
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
return o.aead.Open(dst, o.nonceBuf, src, ad)
}

View File

@ -5,6 +5,8 @@ import (
"fmt"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
@ -14,14 +16,17 @@ const (
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
type Cookie struct {
RemoteAddr string
// The time that the STK was issued (resolution 1 second)
RemoteAddr string
OriginalDestConnectionID protocol.ConnectionID
// The time that the Cookie was issued (resolution 1 second)
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
RemoteAddr []byte
OriginalDestConnectionID []byte
Timestamp int64
}
@ -42,10 +47,11 @@ func NewCookieGenerator() (*CookieGenerator, error) {
}
// NewToken generates a new Cookie for a given source address
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
func (g *CookieGenerator) NewToken(raddr net.Addr, origConnID protocol.ConnectionID) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
RemoteAddr: encodeRemoteAddr(raddr),
OriginalDestConnectionID: origConnID,
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
@ -72,10 +78,14 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &Cookie{
RemoteAddr: decodeRemoteAddr(t.Data),
cookie := &Cookie{
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
if len(t.OriginalDestConnectionID) > 0 {
cookie.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
}
return cookie, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie

View File

@ -0,0 +1,515 @@
package handshake
import (
"crypto/tls"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
)
type messageType uint8
// TLS handshake message types.
const (
typeClientHello messageType = 1
typeServerHello messageType = 2
typeEncryptedExtensions messageType = 8
typeCertificate messageType = 11
typeCertificateRequest messageType = 13
typeCertificateVerify messageType = 15
typeFinished messageType = 20
)
func (m messageType) String() string {
switch m {
case typeClientHello:
return "ClientHello"
case typeServerHello:
return "ServerHello"
case typeEncryptedExtensions:
return "EncryptedExtensions"
case typeCertificate:
return "Certificate"
case typeCertificateRequest:
return "CertificateRequest"
case typeCertificateVerify:
return "CertificateVerify"
case typeFinished:
return "Finished"
default:
return fmt.Sprintf("unknown message type: %d", m)
}
}
type cryptoSetup struct {
tlsConf *qtls.Config
messageChan chan []byte
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
handleParamsCallback func(*TransportParameters)
// There are two ways that an error can occur during the handshake:
// 1. as a return value from qtls.Handshake()
// 2. when new data is passed to the crypto setup via HandleData()
// handshakeErrChan is closed when qtls.Handshake() errors
handshakeErrChan chan struct{}
// HandleData() sends errors on the messageErrChan
messageErrChan chan error
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
handshakeDone chan struct{}
// transport parameters are sent on the receivedTransportParams, as soon as they are received
receivedTransportParams <-chan TransportParameters
// is closed when Close() is called
closeChan chan struct{}
clientHelloWritten bool
clientHelloWrittenChan chan struct{}
initialStream io.Writer
initialAEAD crypto.AEAD
handshakeStream io.Writer
handshakeOpener Opener
handshakeSealer Sealer
opener Opener
sealer Sealer
// TODO: add a 1-RTT stream (used for session tickets)
receivedWriteKey chan struct{}
receivedReadKey chan struct{}
logger utils.Logger
perspective protocol.Perspective
}
var _ qtls.RecordLayer = &cryptoSetup{}
var _ CryptoSetup = &cryptoSetup{}
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
initialStream io.Writer,
handshakeStream io.Writer,
origConnID protocol.ConnectionID,
connID protocol.ConnectionID,
params *TransportParameters,
handleParams func(*TransportParameters),
tlsConf *tls.Config,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
currentVersion protocol.VersionNumber,
logger utils.Logger,
perspective protocol.Perspective,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
extHandler, receivedTransportParams := newExtensionHandlerClient(
params,
origConnID,
initialVersion,
supportedVersions,
currentVersion,
logger,
)
return newCryptoSetup(
initialStream,
handshakeStream,
connID,
extHandler,
receivedTransportParams,
handleParams,
tlsConf,
logger,
perspective,
)
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
params *TransportParameters,
handleParams func(*TransportParameters),
tlsConf *tls.Config,
supportedVersions []protocol.VersionNumber,
currentVersion protocol.VersionNumber,
logger utils.Logger,
perspective protocol.Perspective,
) (CryptoSetup, error) {
extHandler, receivedTransportParams := newExtensionHandlerServer(
params,
supportedVersions,
currentVersion,
logger,
)
cs, _, err := newCryptoSetup(
initialStream,
handshakeStream,
connID,
extHandler,
receivedTransportParams,
handleParams,
tlsConf,
logger,
perspective,
)
return cs, err
}
func newCryptoSetup(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
extHandler tlsExtensionHandler,
transportParamChan <-chan TransportParameters,
handleParams func(*TransportParameters),
tlsConf *tls.Config,
logger utils.Logger,
perspective protocol.Perspective,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
initialAEAD, err := crypto.NewNullAEAD(connID, perspective)
if err != nil {
return nil, nil, err
}
cs := &cryptoSetup{
initialStream: initialStream,
initialAEAD: initialAEAD,
handshakeStream: handshakeStream,
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams,
receivedTransportParams: transportParamChan,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
handshakeErrChan: make(chan struct{}),
messageErrChan: make(chan error, 1),
clientHelloWrittenChan: make(chan struct{}),
messageChan: make(chan []byte, 100),
receivedReadKey: make(chan struct{}),
receivedWriteKey: make(chan struct{}),
closeChan: make(chan struct{}),
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf)
qtlsConf.AlternativeRecordLayer = cs
qtlsConf.GetExtensions = extHandler.GetExtensions
qtlsConf.ReceivedExtensions = extHandler.ReceivedExtensions
cs.tlsConf = qtlsConf
return cs, cs.clientHelloWrittenChan, nil
}
func (h *cryptoSetup) RunHandshake() error {
var conn *qtls.Conn
switch h.perspective {
case protocol.PerspectiveClient:
conn = qtls.Client(nil, h.tlsConf)
case protocol.PerspectiveServer:
conn = qtls.Server(nil, h.tlsConf)
}
// Handle errors that might occur when HandleData() is called.
handshakeErrChan := make(chan error, 1)
handshakeComplete := make(chan struct{})
go func() {
defer close(h.handshakeDone)
if err := conn.Handshake(); err != nil {
handshakeErrChan <- err
return
}
close(handshakeComplete)
}()
select {
case <-h.closeChan:
close(h.messageChan)
// wait until the Handshake() go routine has returned
<-handshakeErrChan
return errors.New("Handshake aborted")
case <-handshakeComplete: // return when the handshake is done
return nil
case err := <-handshakeErrChan:
// if handleMessageFor{server,client} are waiting for some qtls action, make them return
close(h.handshakeErrChan)
return err
case err := <-h.messageErrChan:
// If the handshake errored because of an error that occurred during HandleData(),
// that error message will be more useful than the error message generated by Handshake().
// Close the message chan that qtls is receiving messages from.
// This will make qtls.Handshake() return.
// Thereby the go routine running qtls.Handshake() will return.
close(h.messageChan)
return err
}
}
func (h *cryptoSetup) Close() error {
close(h.closeChan)
// wait until qtls.Handshake() actually returned
<-h.handshakeDone
return nil
}
// handleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
// It returns if it is done with messages on the same encryption level.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
h.messageErrChan <- err
return false
}
h.messageChan <- data
switch h.perspective {
case protocol.PerspectiveClient:
return h.handleMessageForClient(msgType)
case protocol.PerspectiveServer:
return h.handleMessageForServer(msgType)
default:
panic("")
}
}
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
var expected protocol.EncryptionLevel
switch msgType {
case typeClientHello,
typeServerHello:
expected = protocol.EncryptionInitial
case typeEncryptedExtensions,
typeCertificate,
typeCertificateRequest,
typeCertificateVerify,
typeFinished:
expected = protocol.EncryptionHandshake
default:
return fmt.Errorf("unexpected handshake message: %d", msgType)
}
if encLevel != expected {
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
}
return nil
}
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
switch msgType {
case typeClientHello:
select {
case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params)
case <-h.handshakeErrChan:
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return false
}
// get the 1-RTT write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return false
}
// get the handshake read key
// TODO: check that the initial stream doesn't have any more data
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return false
}
return true
case typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return false
}
return true
default:
panic("unexpected handshake message")
}
}
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
switch msgType {
case typeServerHello:
// get the handshake read key
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return false
}
return true
case typeEncryptedExtensions:
select {
case params := <-h.receivedTransportParams:
h.handleParamsCallback(&params)
case <-h.handshakeErrChan:
return false
}
return false
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return false
}
// While the order of these two is not defined by the TLS spec,
// we have to do it on the same order as our TLS library does it.
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeErrChan:
return false
}
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeErrChan:
return false
}
return true
default:
panic("unexpected handshake message: ")
}
}
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
// TODO: add some error handling here (when the session is closed)
msg, ok := <-h.messageChan
if !ok {
return nil, errors.New("error while handling the handshake message")
}
return msg, nil
}
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
opener := newOpener(suite.AEAD(key, iv), iv)
switch h.readEncLevel {
case protocol.EncryptionInitial:
h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = opener
h.logger.Debugf("Installed Handshake Read keys")
case protocol.EncryptionHandshake:
h.readEncLevel = protocol.Encryption1RTT
h.opener = opener
h.logger.Debugf("Installed 1-RTT Read keys")
default:
panic("unexpected read encryption level")
}
h.receivedReadKey <- struct{}{}
}
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
sealer := newSealer(suite.AEAD(key, iv), iv)
switch h.writeEncLevel {
case protocol.EncryptionInitial:
h.writeEncLevel = protocol.EncryptionHandshake
h.handshakeSealer = sealer
h.logger.Debugf("Installed Handshake Write keys")
case protocol.EncryptionHandshake:
h.writeEncLevel = protocol.Encryption1RTT
h.sealer = sealer
h.logger.Debugf("Installed 1-RTT Write keys")
default:
panic("unexpected write encryption level")
}
h.receivedWriteKey <- struct{}{}
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
switch h.writeEncLevel {
case protocol.EncryptionInitial:
// assume that the first WriteRecord call contains the ClientHello
n, err := h.initialStream.Write(p)
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
h.clientHelloWritten = true
close(h.clientHelloWrittenChan)
}
return n, err
case protocol.EncryptionHandshake:
return h.handshakeStream.Write(p)
default:
return 0, fmt.Errorf("unexpected write encryption level: %s", h.writeEncLevel)
}
}
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
if h.sealer != nil {
return protocol.Encryption1RTT, h.sealer
}
if h.handshakeSealer != nil {
return protocol.EncryptionHandshake, h.handshakeSealer
}
return protocol.EncryptionInitial, h.initialAEAD
}
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
switch level {
case protocol.EncryptionInitial:
return h.initialAEAD, nil
case protocol.EncryptionHandshake:
if h.handshakeSealer == nil {
return nil, errNoSealer
}
return h.handshakeSealer, nil
case protocol.Encryption1RTT:
if h.sealer == nil {
return nil, errNoSealer
}
return h.sealer, nil
default:
return nil, errNoSealer
}
}
func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
return h.initialAEAD.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
if h.handshakeOpener == nil {
return nil, errors.New("no handshake opener")
}
return h.handshakeOpener.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
if h.opener == nil {
return nil, errors.New("no 1-RTT opener")
}
return h.opener.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) ConnectionState() ConnectionState {
// TODO: return the connection state
return ConnectionState{}
}

View File

@ -1,543 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type cryptoSetupClient struct {
mutex sync.RWMutex
hostname string
connID protocol.ConnectionID
version protocol.VersionNumber
initialVersion protocol.VersionNumber
negotiatedVersions []protocol.VersionNumber
cryptoStream io.ReadWriter
serverConfig *serverConfigClient
stk []byte
sno []byte
nonc []byte
proof []byte
chloForSignature []byte
lastSentCHLO []byte
certManager crypto.CertManager
divNonceChan chan struct{}
diversificationNonce []byte
clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation QuicCryptoKeyDerivationFunction
receivedSecurePacket bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
params *TransportParameters
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupClient{}
var (
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available")
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated")
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces")
)
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
func NewCryptoSetupClient(
cryptoStream io.ReadWriter,
hostname string,
connID protocol.ConnectionID,
version protocol.VersionNumber,
tlsConfig *tls.Config,
params *TransportParameters,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
divNonceChan := make(chan struct{})
cs := &cryptoSetupClient{
cryptoStream: cryptoStream,
hostname: hostname,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConfig),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
// The server might have sent greased versions in the Version Negotiation packet.
// We need strip those from the list, since they won't be included in the handshake tag.
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
divNonceChan: divNonceChan,
logger: logger,
}
return cs, nil
}
func (h *cryptoSetupClient) HandleCryptoStream() error {
messageChan := make(chan HandshakeMessage)
errorChan := make(chan error, 1)
go func() {
for {
message, err := ParseHandshakeMessage(h.cryptoStream)
if err != nil {
errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error())
return
}
messageChan <- message
}
}()
for {
if err := h.maybeUpgradeCrypto(); err != nil {
return err
}
h.mutex.RLock()
sendCHLO := h.secureAEAD == nil
h.mutex.RUnlock()
if sendCHLO {
if err := h.sendCHLO(); err != nil {
return err
}
}
var message HandshakeMessage
select {
case <-h.divNonceChan:
// there's no message to process, but we should try upgrading the crypto again
continue
case message = <-messageChan:
case err := <-errorChan:
return err
}
h.logger.Debugf("Got %s", message)
switch message.Tag {
case TagREJ:
if err := h.handleREJMessage(message.Data); err != nil {
return err
}
case TagSHLO:
params, err := h.handleSHLOMessage(message.Data)
if err != nil {
return err
}
// blocks until the session has received the parameters
h.paramsChan <- *params
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
default:
return qerr.InvalidCryptoMessageType
}
}
}
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
var err error
if stk, ok := cryptoData[TagSTK]; ok {
h.stk = stk
}
if sno, ok := cryptoData[TagSNO]; ok {
h.sno = sno
}
// TODO: what happens if the server sends a different server config in two packets?
if scfg, ok := cryptoData[TagSCFG]; ok {
h.serverConfig, err = parseServerConfig(scfg)
if err != nil {
return err
}
if h.serverConfig.IsExpired() {
return qerr.CryptoServerConfigExpired
}
// now that we have a server config, we can use its OBIT value to generate a client nonce
if len(h.nonc) == 0 {
err = h.generateClientNonce()
if err != nil {
return err
}
}
}
if proof, ok := cryptoData[TagPROF]; ok {
h.proof = proof
h.chloForSignature = h.lastSentCHLO
}
if crt, ok := cryptoData[TagCERT]; ok {
err := h.certManager.SetData(crt)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
err = h.certManager.Verify(h.hostname)
if err != nil {
h.logger.Infof("Certificate validation failed: %s", err.Error())
return qerr.ProofInvalid
}
}
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
if !validProof {
h.logger.Infof("Server proof verification failed")
return qerr.ProofInvalid
}
h.serverVerified = true
}
return nil
}
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.receivedSecurePacket {
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
}
if sno, ok := cryptoData[TagSNO]; ok {
h.sno = sno
}
serverPubs, ok := cryptoData[TagPUBS]
if !ok {
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
verTag, ok := cryptoData[TagVER]
if !ok {
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
}
if !h.validateVersionList(verTag) {
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
nonce := append(h.nonc, h.sno...)
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
if err != nil {
return nil, err
}
leafCert := h.certManager.GetLeafCert()
h.forwardSecureAEAD, err = h.keyDerivation(
true,
ephermalSharedSecret,
nonce,
h.connID,
h.lastSentCHLO,
h.serverConfig.Get(),
leafCert,
nil,
protocol.PerspectiveClient,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
params, err := readHelloMap(cryptoData)
if err != nil {
return nil, qerr.InvalidCryptoMessageParameter
}
return params, nil
}
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
numNegotiatedVersions := len(h.negotiatedVersions)
if numNegotiatedVersions == 0 {
return true
}
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
return false
}
b := bytes.NewReader(verTags)
for i := 0; i < numNegotiatedVersions; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil { // should never occur, since the length was already checked
return false
}
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
return false
}
}
return true
}
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
return data, protocol.EncryptionForwardSecure, nil
}
return nil, protocol.EncryptionUnspecified, err
}
if h.secureAEAD != nil {
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return data, protocol.EncryptionSecure, nil
}
if h.receivedSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, nil
}
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
} else if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.secureAEAD
} else {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
}
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no secureAEAD")
}
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
}
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupClient: no encryption level specified")
}
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
HandshakeComplete: h.forwardSecureAEAD != nil,
PeerCertificates: h.certManager.GetChain(),
}
}
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
h.mutex.Lock()
if len(h.diversificationNonce) > 0 {
defer h.mutex.Unlock()
if !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
}
return nil
}
h.diversificationNonce = divNonce
h.mutex.Unlock()
h.divNonceChan <- struct{}{}
return nil
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos))
}
b := &bytes.Buffer{}
tags, err := h.getTags()
if err != nil {
return err
}
h.addPadding(tags)
message := HandshakeMessage{
Tag: TagCHLO,
Data: tags,
}
h.logger.Debugf("Sending %s", message)
message.Write(b)
_, err = h.cryptoStream.Write(b.Bytes())
if err != nil {
return err
}
h.lastSentCHLO = b.Bytes()
return nil
}
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
tags := h.params.getHelloMap()
tags[TagSNI] = []byte(h.hostname)
tags[TagPDMD] = []byte("X509")
ccs := h.certManager.GetCommonCertificateHashes()
if len(ccs) > 0 {
tags[TagCCS] = ccs
}
versionTag := make([]byte, 4)
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
tags[TagVER] = versionTag
if len(h.stk) > 0 {
tags[TagSTK] = h.stk
}
if len(h.sno) > 0 {
tags[TagSNO] = h.sno
}
if h.serverConfig != nil {
tags[TagSCID] = h.serverConfig.ID
leafCert := h.certManager.GetLeafCert()
if leafCert != nil {
certHash, _ := h.certManager.GetLeafCertHash()
xlct := make([]byte, 8)
binary.LittleEndian.PutUint64(xlct, certHash)
tags[TagNONC] = h.nonc
tags[TagXLCT] = xlct
tags[TagKEXS] = []byte("C255")
tags[TagAEAD] = []byte("AESG")
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
}
}
return tags, nil
}
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
var size int
for _, tag := range tags {
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
}
paddingSize := protocol.MinClientHelloSize - size
if paddingSize > 0 {
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
}
}
func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if !h.serverVerified {
return nil
}
h.mutex.Lock()
defer h.mutex.Unlock()
leafCert := h.certManager.GetLeafCert()
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) {
var err error
var nonce []byte
if h.sno == nil {
nonce = h.nonc
} else {
nonce = append(h.nonc, h.sno...)
}
h.secureAEAD, err = h.keyDerivation(
false,
h.serverConfig.sharedSecret,
nonce,
h.connID,
h.lastSentCHLO,
h.serverConfig.Get(),
leafCert,
h.diversificationNonce,
protocol.PerspectiveClient,
)
if err != nil {
return err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
}
return nil
}
func (h *cryptoSetupClient) generateClientNonce() error {
if len(h.nonc) > 0 {
return errClientNonceAlreadyExists
}
nonc := make([]byte, 32)
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix()))
if len(h.serverConfig.obit) != 8 {
return errNoObitForClientNonce
}
copy(nonc[4:12], h.serverConfig.obit)
_, err := rand.Read(nonc[12:])
if err != nil {
return err
}
h.nonc = nonc
return nil
}

View File

@ -1,467 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// QuicCryptoKeyDerivationFunction is used for key derivation
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
// KeyExchangeFunction is used to make a new KEX
type KeyExchangeFunction func() (crypto.KeyExchange, error)
// The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct {
mutex sync.RWMutex
connID protocol.ConnectionID
remoteAddr net.Addr
scfg *ServerConfig
diversificationNonce []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
acceptSTKCallback func(net.Addr, *Cookie) bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
receivedForwardSecurePacket bool
receivedSecurePacket bool
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
receivedParams bool
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
cryptoStream io.ReadWriter
params *TransportParameters
sni string // need to fill out the ConnectionState
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupServer{}
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
// NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
remoteAddr net.Addr,
version protocol.VersionNumber,
divNonce []byte,
scfg *ServerConfig,
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupServer{
cryptoStream: cryptoStream,
connID: connID,
remoteAddr: remoteAddr,
version: version,
supportedVersions: supportedVersions,
diversificationNonce: divNonce,
scfg: scfg,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
params: params,
acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
logger: logger,
}, nil
}
// HandleCryptoStream reads and writes messages on the crypto stream
func (h *cryptoSetupServer) HandleCryptoStream() error {
for {
var chloData bytes.Buffer
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
if err != nil {
return qerr.HandshakeFailed
}
if message.Tag != TagCHLO {
return qerr.InvalidCryptoMessageType
}
h.logger.Debugf("Got %s", message)
done, err := h.handleMessage(chloData.Bytes(), message.Data)
if err != nil {
return err
}
if done {
return nil
}
}
}
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
return false, ErrNSTPExperiment
}
sniSlice, ok := cryptoData[TagSNI]
if !ok {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
sni := string(sniSlice)
if sni == "" {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
h.sni = sni
// prevent version downgrade attacks
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
verSlice, ok := cryptoData[TagVER]
if !ok {
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")
}
if len(verSlice) != 4 {
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")
}
ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice))
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
var reply []byte
var err error
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
if err != nil {
return false, err
}
params, err := readHelloMap(cryptoData)
if err != nil {
return false, err
}
// blocks until the session has received the parameters
if !h.receivedParams {
h.receivedParams = true
h.paramsChan <- *params
}
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
reply, err = h.handleCHLO(sni, chloData, cryptoData)
if err != nil {
return false, err
}
if _, err := h.cryptoStream.Write(reply); err != nil {
return false, err
}
h.handshakeEvent <- struct{}{}
close(h.sentSHLO)
return true, nil
}
// We have an inchoate or non-matching CHLO, we now send a rejection
reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData)
if err != nil {
return false, err
}
_, err = h.cryptoStream.Write(reply)
return false, err
}
// Open a message
func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
h.receivedForwardSecurePacket = true
// wait for the send on the handshakeEvent chan
<-h.sentSHLO
close(h.handshakeEvent)
}
return res, protocol.EncryptionForwardSecure, nil
}
if h.receivedForwardSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
if h.secureAEAD != nil {
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return res, protocol.EncryptionSecure, nil
}
if h.receivedSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return res, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, err
}
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.secureAEAD
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no secureAEAD")
}
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
}
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupServer: no encryption level specified")
}
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
if _, ok := cryptoData[TagPUBS]; !ok {
return true
}
scid, ok := cryptoData[TagSCID]
if !ok || !bytes.Equal(h.scfg.ID, scid) {
return true
}
xlctTag, ok := cryptoData[TagXLCT]
if !ok || len(xlctTag) != 8 {
return true
}
xlct := binary.LittleEndian.Uint64(xlctTag)
if crypto.HashCert(cert) != xlct {
return true
}
return !h.acceptSTK(cryptoData[TagSTK])
}
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
if err != nil {
h.logger.Debugf("STK invalid: %s", err.Error())
return false
}
return h.acceptSTKCallback(h.remoteAddr, stk)
}
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
if err != nil {
return nil, err
}
replyMap := map[Tag][]byte{
TagSCFG: h.scfg.Get(),
TagSTK: token,
TagSVID: []byte("quic-go"),
}
if h.acceptSTK(cryptoData[TagSTK]) {
proof, err := h.scfg.Sign(sni, chlo)
if err != nil {
return nil, err
}
commonSetHashes := cryptoData[TagCCS]
cachedCertsHashes := cryptoData[TagCCRT]
certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes)
if err != nil {
return nil, err
}
// Token was valid, send more details
replyMap[TagPROF] = proof
replyMap[TagCERT] = certCompressed
}
message := HandshakeMessage{
Tag: TagREJ,
Data: replyMap,
}
var serverReply bytes.Buffer
message.Write(&serverReply)
h.logger.Debugf("Sending %s", message)
return serverReply.Bytes(), nil
}
func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) {
// We have a CHLO matching our server config, we can continue with the 0-RTT handshake
sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
}
h.mutex.Lock()
defer h.mutex.Unlock()
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
if err != nil {
return nil, err
}
serverNonce := make([]byte, 32)
if _, err = rand.Read(serverNonce); err != nil {
return nil, err
}
clientNonce := cryptoData[TagNONC]
err = h.validateClientNonce(clientNonce)
if err != nil {
return nil, err
}
aead := cryptoData[TagAEAD]
if !bytes.Equal(aead, []byte("AESG")) {
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
}
kexs := cryptoData[TagKEXS]
if !bytes.Equal(kexs, []byte("C255")) {
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
}
h.secureAEAD, err = h.keyDerivation(
false,
sharedSecret,
clientNonce,
h.connID,
data,
h.scfg.Get(),
certUncompressed,
h.diversificationNonce,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
// Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer
fsNonce.Write(clientNonce)
fsNonce.Write(serverNonce)
ephermalKex, err := h.keyExchange()
if err != nil {
return nil, err
}
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
}
h.forwardSecureAEAD, err = h.keyDerivation(
true,
ephermalSharedSecret,
fsNonce.Bytes(),
h.connID,
data,
h.scfg.Get(),
certUncompressed,
nil,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
replyMap := h.params.getHelloMap()
// add crypto parameters
verTag := &bytes.Buffer{}
for _, v := range h.supportedVersions {
utils.BigEndian.WriteUint32(verTag, uint32(v))
}
replyMap[TagPUBS] = ephermalKex.PublicKey()
replyMap[TagSNO] = serverNonce
replyMap[TagVER] = verTag.Bytes()
// note that the SHLO *has* to fit into one packet
message := HandshakeMessage{
Tag: TagSHLO,
Data: replyMap,
}
var reply bytes.Buffer
message.Write(&reply)
h.logger.Debugf("Sending %s", message)
return reply.Bytes(), nil
}
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
ServerName: h.sni,
HandshakeComplete: h.receivedForwardSecurePacket,
}
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
}
if !bytes.Equal(nonce[4:12], h.scfg.obit) {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")
}
return nil
}

View File

@ -1,163 +0,0 @@
package handshake
import (
"errors"
"fmt"
"io"
"sync"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// KeyDerivationFunction is used for key derivation
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
type cryptoSetupTLS struct {
mutex sync.RWMutex
perspective protocol.Perspective
keyDerivation KeyDerivationFunction
nullAEAD crypto.AEAD
aead crypto.AEAD
tls mintTLS
conn *cryptoStreamConn
handshakeEvent chan<- struct{}
}
var _ CryptoSetupTLS = &cryptoSetupTLS{}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Server(conn, config)
return &cryptoSetupTLS{
tls: tls,
conn: conn,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}, nil
}
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Client(conn, config)
return &cryptoSetupTLS{
tls: tls,
conn: conn,
perspective: protocol.PerspectiveClient,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}, nil
}
func (h *cryptoSetupTLS) HandleCryptoStream() error {
for {
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
}
state := h.tls.ConnectionState().HandshakeState
if err := h.conn.Flush(); err != nil {
return err
}
if state == mint.StateClientConnected || state == mint.StateServerConnected {
break
}
}
aead, err := h.keyDerivation(h.tls, h.perspective)
if err != nil {
return err
}
h.mutex.Lock()
h.aead = aead
h.mutex.Unlock()
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
return nil
}
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.aead == nil {
return nil, errors.New("no 1-RTT sealer")
}
return h.aead.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.aead != nil {
return protocol.EncryptionForwardSecure, h.aead
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", encLevel.String())
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionForwardSecure:
if h.aead == nil {
return nil, errNoSealer
}
return h.aead, nil
default:
return nil, errNoSealer
}
}
func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
mintConnState := h.tls.ConnectionState()
return ConnectionState{
// TODO: set the ServerName, once mint exports it
HandshakeComplete: h.aead != nil,
PeerCertificates: mintConnState.PeerCertificates,
}
}

View File

@ -1,69 +0,0 @@
package handshake
import (
"bytes"
"io"
"net"
"time"
)
type cryptoStreamConn struct {
buffer *bytes.Buffer
stream io.ReadWriter
}
var _ net.Conn = &cryptoStreamConn{}
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
return &cryptoStreamConn{
stream: stream,
buffer: &bytes.Buffer{},
}
}
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
return c.stream.Read(b)
}
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
return c.buffer.Write(p)
}
func (c *cryptoStreamConn) Flush() error {
if c.buffer.Len() == 0 {
return nil
}
_, err := c.stream.Write(c.buffer.Bytes())
c.buffer.Reset()
return err
}
// Close is not implemented
func (c *cryptoStreamConn) Close() error {
return nil
}
// LocalAddr is not implemented
func (c *cryptoStreamConn) LocalAddr() net.Addr {
return nil
}
// RemoteAddr is not implemented
func (c *cryptoStreamConn) RemoteAddr() net.Addr {
return nil
}
// SetReadDeadline is not implemented
func (c *cryptoStreamConn) SetReadDeadline(time.Time) error {
return nil
}
// SetWriteDeadline is not implemented
func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// SetDeadline is not implemented
func (c *cryptoStreamConn) SetDeadline(time.Time) error {
return nil
}

View File

@ -1,48 +0,0 @@
package handshake
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
kexLifetime = protocol.EphermalKeyLifetime
kexCurrent crypto.KeyExchange
kexCurrentTime time.Time
kexMutex sync.RWMutex
)
// getEphermalKEX returns the currently active KEX, which changes every protocol.EphermalKeyLifetime
// See the explanation from the QUIC crypto doc:
//
// A single connection is the usual scope for forward security, but the security
// difference between an ephemeral key used for a single connection, and one
// used for all connections for 60 seconds is negligible. Thus we can amortise
// the Diffie-Hellman key generation at the server over all the connections in a
// small time span.
func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.RLock()
res := kexCurrent
t := kexCurrentTime
kexMutex.RUnlock()
if res != nil && time.Since(t) < kexLifetime {
return res, nil
}
kexMutex.Lock()
defer kexMutex.Unlock()
// Check if still unfulfilled
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
kex, err := crypto.NewCurve25519KEX()
if err != nil {
return nil, err
}
kexCurrent = kex
kexCurrentTime = time.Now()
return kexCurrent, nil
}
return kexCurrent, nil
}

View File

@ -1,137 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"sort"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// A HandshakeMessage is a handshake message
type HandshakeMessage struct {
Tag Tag
Data map[Tag][]byte
}
var _ fmt.Stringer = &HandshakeMessage{}
// ParseHandshakeMessage reads a crypto message
func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) {
slice4 := make([]byte, 4)
if _, err := io.ReadFull(r, slice4); err != nil {
return HandshakeMessage{}, err
}
messageTag := Tag(binary.LittleEndian.Uint32(slice4))
if _, err := io.ReadFull(r, slice4); err != nil {
return HandshakeMessage{}, err
}
nPairs := binary.LittleEndian.Uint32(slice4)
if nPairs > protocol.CryptoMaxParams {
return HandshakeMessage{}, qerr.CryptoTooManyEntries
}
index := make([]byte, nPairs*8)
if _, err := io.ReadFull(r, index); err != nil {
return HandshakeMessage{}, err
}
resultMap := map[Tag][]byte{}
var dataStart uint32
for indexPos := 0; indexPos < int(nPairs)*8; indexPos += 8 {
tag := Tag(binary.LittleEndian.Uint32(index[indexPos : indexPos+4]))
dataEnd := binary.LittleEndian.Uint32(index[indexPos+4 : indexPos+8])
dataLen := dataEnd - dataStart
if dataLen > protocol.CryptoParameterMaxLength {
return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long")
}
data := make([]byte, dataLen)
if _, err := io.ReadFull(r, data); err != nil {
return HandshakeMessage{}, err
}
resultMap[tag] = data
dataStart = dataEnd
}
return HandshakeMessage{
Tag: messageTag,
Data: resultMap}, nil
}
// Write writes a crypto message
func (h HandshakeMessage) Write(b *bytes.Buffer) {
data := h.Data
utils.LittleEndian.WriteUint32(b, uint32(h.Tag))
utils.LittleEndian.WriteUint16(b, uint16(len(data)))
utils.LittleEndian.WriteUint16(b, 0)
// Save current position in the buffer, so that we can update the index in-place later
indexStart := b.Len()
indexData := make([]byte, 8*len(data))
b.Write(indexData) // Will be updated later
offset := uint32(0)
for i, t := range h.getTagsSorted() {
v := data[t]
b.Write(v)
offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
}
// Now we write the index data for real
copy(b.Bytes()[indexStart:], indexData)
}
func (h *HandshakeMessage) getTagsSorted() []Tag {
tags := make([]Tag, len(h.Data))
i := 0
for t := range h.Data {
tags[i] = t
i++
}
sort.Slice(tags, func(i, j int) bool {
return tags[i] < tags[j]
})
return tags
}
func (h HandshakeMessage) String() string {
var pad string
res := tagToString(h.Tag) + ":\n"
for _, tag := range h.getTagsSorted() {
if tag == TagPAD {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
} else {
res += fmt.Sprintf("\t%s: %#v\n", tagToString(tag), string(h.Data[tag]))
}
}
if len(pad) > 0 {
res += pad
}
return res
}
func tagToString(tag Tag) string {
b := make([]byte, 4)
binary.LittleEndian.PutUint32(b, uint32(tag))
for i := range b {
if b[i] == 0 {
b[i] = ' '
}
}
return string(b)
}

View File

@ -2,54 +2,43 @@ package handshake
import (
"crypto/x509"
"io"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
)
// Opener opens a packet
type Opener interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// Sealer seals a packet
type Sealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
Overhead() int
}
// mintTLS combines some methods needed to interact with mint.
type mintTLS interface {
crypto.TLSExporter
Handshake() mint.Alert
// A tlsExtensionHandler sends and received the QUIC TLS extension.
type tlsExtensionHandler interface {
GetExtensions(msgType uint8) []qtls.Extension
ReceivedExtensions(msgType uint8, exts []qtls.Extension) error
}
// A TLSExtensionHandler sends and received the QUIC TLS extension.
// It provides the parameters sent by the peer on a channel.
type TLSExtensionHandler interface {
Send(mint.HandshakeType, *mint.ExtensionList) error
Receive(mint.HandshakeType, *mint.ExtensionList) error
GetPeerParams() <-chan TransportParameters
}
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake() error
io.Closer
type baseCryptoSetup interface {
HandleCryptoStream() error
HandleMessage([]byte, protocol.EncryptionLevel) bool
ConnectionState() ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// CryptoSetup is the crypto setup used by gQUIC
type CryptoSetup interface {
baseCryptoSetup
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
}
// CryptoSetupTLS is the crypto setup used by IETF QUIC
type CryptoSetupTLS interface {
baseCryptoSetup
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
}
// ConnectionState records basic details about the QUIC connection.

View File

@ -1,3 +0,0 @@
package handshake
//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS"

View File

@ -0,0 +1,48 @@
package handshake
import (
"crypto/tls"
"github.com/marten-seemann/qtls"
)
func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
if c == nil {
c = &tls.Config{}
}
// QUIC requires TLS 1.3 or newer
if c.MinVersion < qtls.VersionTLS13 {
c.MinVersion = qtls.VersionTLS13
}
if c.MaxVersion < qtls.VersionTLS13 {
c.MaxVersion = qtls.VersionTLS13
}
return &qtls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
// TODO: make GetCertificate work
// GetCertificate: c.GetCertificate,
GetClientCertificate: c.GetClientCertificate,
// TODO: make GetConfigForClient work
// GetConfigForClient: c.GetConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
KeyLogWriter: c.KeyLogWriter,
}
}

View File

@ -1,73 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"github.com/lucas-clemente/quic-go/internal/crypto"
)
// ServerConfig is a server config
type ServerConfig struct {
kex crypto.KeyExchange
certChain crypto.CertChain
ID []byte
obit []byte
cookieGenerator *CookieGenerator
}
// NewServerConfig creates a new server config
func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
id := make([]byte, 16)
_, err := rand.Read(id)
if err != nil {
return nil, err
}
obit := make([]byte, 8)
if _, err = rand.Read(obit); err != nil {
return nil, err
}
cookieGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
}
return &ServerConfig{
kex: kex,
certChain: certChain,
ID: id,
obit: obit,
cookieGenerator: cookieGenerator,
}, nil
}
// Get the server config binary representation
func (s *ServerConfig) Get() []byte {
var serverConfig bytes.Buffer
msg := HandshakeMessage{
Tag: TagSCFG,
Data: map[Tag][]byte{
TagSCID: s.ID,
TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
TagOBIT: s.obit,
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
},
}
msg.Write(&serverConfig)
return serverConfig.Bytes()
}
// Sign the server config and CHLO with the server's keyData
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
return s.certChain.SignServerProof(sni, chlo, s.Get())
}
// GetCertsCompressed returns the certificate data
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
}

View File

@ -1,184 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type serverConfigClient struct {
raw []byte
ID []byte
obit []byte
expiry time.Time
kex crypto.KeyExchange
sharedSecret []byte
}
var (
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
)
// parseServerConfig parses a server config
func parseServerConfig(data []byte) (*serverConfigClient, error) {
message, err := ParseHandshakeMessage(bytes.NewReader(data))
if err != nil {
return nil, err
}
if message.Tag != TagSCFG {
return nil, errMessageNotServerConfig
}
scfg := &serverConfigClient{raw: data}
err = scfg.parseValues(message.Data)
if err != nil {
return nil, err
}
return scfg, nil
}
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
// SCID
scfgID, ok := tagMap[TagSCID]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
}
if len(scfgID) != 16 {
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
}
s.ID = scfgID
// KEXS
// TODO: setup Key Exchange
kexs, ok := tagMap[TagKEXS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
}
if len(kexs)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
}
c255Foundat := -1
for i := 0; i < len(kexs)/4; i++ {
if bytes.Equal(kexs[4*i:4*i+4], []byte("C255")) {
c255Foundat = i
break
}
}
if c255Foundat < 0 {
return qerr.Error(qerr.CryptoNoSupport, "KEXS: Could not find C255, other key exchanges are not supported")
}
// AEAD
aead, ok := tagMap[TagAEAD]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
}
if len(aead)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
}
var aesgFound bool
for i := 0; i < len(aead)/4; i++ {
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
aesgFound = true
break
}
}
if !aesgFound {
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
}
// PUBS
pubs, ok := tagMap[TagPUBS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
var pubsKexs []struct {
Length uint32
Value []byte
}
var lastLen uint32
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
// the PUBS value is always prepended by 3 byte little endian length field
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
if err != nil {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
}
if lastLen == 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
if i+3+int(lastLen) > len(pubs) {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
pubsKexs = append(pubsKexs, struct {
Length uint32
Value []byte
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
}
if c255Foundat >= len(pubsKexs) {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
}
if pubsKexs[c255Foundat].Length != 32 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
var err error
s.kex, err = crypto.NewCurve25519KEX()
if err != nil {
return err
}
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
if err != nil {
return err
}
// OBIT
obit, ok := tagMap[TagOBIT]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
}
if len(obit) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
}
s.obit = obit
// EXPY
expy, ok := tagMap[TagEXPY]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
}
if len(expy) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
}
// make sure that the value doesn't overflow an int64
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
s.expiry = time.Unix(int64(expyTimestamp), 0)
// TODO: implement VER
return nil
}
func (s *serverConfigClient) IsExpired() bool {
return s.expiry.Before(time.Now())
}
func (s *serverConfigClient) Get() []byte {
return s.raw
}

View File

@ -1,93 +0,0 @@
package handshake
// A Tag in the QUIC crypto
type Tag uint32
const (
// TagCHLO is a client hello
TagCHLO Tag = 'C' + 'H'<<8 + 'L'<<16 + 'O'<<24
// TagREJ is a server hello rejection
TagREJ Tag = 'R' + 'E'<<8 + 'J'<<16
// TagSCFG is a server config
TagSCFG Tag = 'S' + 'C'<<8 + 'F'<<16 + 'G'<<24
// TagPAD is padding
TagPAD Tag = 'P' + 'A'<<8 + 'D'<<16
// TagSNI is the server name indication
TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16
// TagVER is the QUIC version
TagVER Tag = 'V' + 'E'<<8 + 'R'<<16
// TagCCS are the hashes of the common certificate sets
TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16
// TagCCRT are the hashes of the cached certificates
TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24
// TagMSPC is max streams per connection
TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24
// TagMIDS is max incoming dyanamic streams
TagMIDS Tag = 'M' + 'I'<<8 + 'D'<<16 + 'S'<<24
// TagUAID is the user agent ID
TagUAID Tag = 'U' + 'A'<<8 + 'I'<<16 + 'D'<<24
// TagSVID is the server ID (unofficial tag by us :)
TagSVID Tag = 'S' + 'V'<<8 + 'I'<<16 + 'D'<<24
// TagTCID is truncation of the connection ID
TagTCID Tag = 'T' + 'C'<<8 + 'I'<<16 + 'D'<<24
// TagPDMD is the proof demand
TagPDMD Tag = 'P' + 'D'<<8 + 'M'<<16 + 'D'<<24
// TagSRBF is the socket receive buffer
TagSRBF Tag = 'S' + 'R'<<8 + 'B'<<16 + 'F'<<24
// TagICSL is the idle connection state lifetime
TagICSL Tag = 'I' + 'C'<<8 + 'S'<<16 + 'L'<<24
// TagNONP is the client proof nonce
TagNONP Tag = 'N' + 'O'<<8 + 'N'<<16 + 'P'<<24
// TagSCLS is the silently close timeout
TagSCLS Tag = 'S' + 'C'<<8 + 'L'<<16 + 'S'<<24
// TagCSCT is the signed cert timestamp (RFC6962) of leaf cert
TagCSCT Tag = 'C' + 'S'<<8 + 'C'<<16 + 'T'<<24
// TagCOPT are the connection options
TagCOPT Tag = 'C' + 'O'<<8 + 'P'<<16 + 'T'<<24
// TagCFCW is the initial session/connection flow control receive window
TagCFCW Tag = 'C' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagSFCW is the initial stream flow control receive window.
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagNSTP is the no STOP_WAITING experiment
// currently unsupported by quic-go
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
// TagSTK is the source-address token
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
// TagSNO is the server nonce
TagSNO Tag = 'S' + 'N'<<8 + 'O'<<16
// TagPROF is the server proof
TagPROF Tag = 'P' + 'R'<<8 + 'O'<<16 + 'F'<<24
// TagNONC is the client nonce
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
// TagXLCT is the expected leaf certificate
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
// TagSCID is the server config ID
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
// TagKEXS is the list of key exchange algos
TagKEXS Tag = 'K' + 'E'<<8 + 'X'<<16 + 'S'<<24
// TagAEAD is the list of AEAD algos
TagAEAD Tag = 'A' + 'E'<<8 + 'A'<<16 + 'D'<<24
// TagPUBS is the public value for the KEX
TagPUBS Tag = 'P' + 'U'<<8 + 'B'<<16 + 'S'<<24
// TagOBIT is the client orbit
TagOBIT Tag = 'O' + 'B'<<8 + 'I'<<16 + 'T'<<24
// TagEXPY is the server config expiry
TagEXPY Tag = 'E' + 'X'<<8 + 'P'<<16 + 'Y'<<24
// TagCERT is the CERT data
TagCERT Tag = 0xff545243
// TagSHLO is the server hello
TagSHLO Tag = 'S' + 'H'<<8 + 'L'<<16 + 'O'<<24
// TagPRST is the public reset tag
TagPRST Tag = 'P' + 'R'<<8 + 'S'<<16 + 'T'<<24
// TagRSEQ is the public reset rejected packet number
TagRSEQ Tag = 'R' + 'S'<<8 + 'E'<<16 + 'Q'<<24
// TagRNON is the public reset nonce
TagRNON Tag = 'R' + 'N'<<8 + 'O'<<16 + 'N'<<24
)

View File

@ -6,26 +6,12 @@ import (
"errors"
"fmt"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type transportParameterID uint16
const quicTLSExtensionType = 0xff5
const (
initialMaxStreamDataParameterID transportParameterID = 0x0
initialMaxDataParameterID transportParameterID = 0x1
initialMaxBidiStreamsParameterID transportParameterID = 0x2
idleTimeoutParameterID transportParameterID = 0x3
maxPacketSizeParameterID transportParameterID = 0x5
statelessResetTokenParameterID transportParameterID = 0x6
initialMaxUniStreamsParameterID transportParameterID = 0x8
disableMigrationParameterID transportParameterID = 0x9
)
type clientHelloTransportParameters struct {
InitialVersion protocol.VersionNumber
Parameters TransportParameters
@ -52,7 +38,7 @@ func (p *clientHelloTransportParameters) Unmarshal(data []byte) error {
if len(data) != paramsLen {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
}
return p.Parameters.unmarshal(data)
return p.Parameters.unmarshal(data, protocol.PerspectiveClient)
}
type encryptedExtensionsTransportParameters struct {
@ -100,24 +86,5 @@ func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error {
if len(data) != paramsLen {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
}
return p.Parameters.unmarshal(data)
}
type tlsExtensionBody struct {
data []byte
}
var _ mint.ExtensionBody = &tlsExtensionBody{}
func (e *tlsExtensionBody) Type() mint.ExtensionType {
return quicTLSExtensionType
}
func (e *tlsExtensionBody) Marshal() ([]byte, error) {
return e.data, nil
}
func (e *tlsExtensionBody) Unmarshal(data []byte) (int, error) {
e.data = data
return len(data), nil
return p.Parameters.unmarshal(data, protocol.PerspectiveServer)
}

View File

@ -4,17 +4,17 @@ import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
)
type extensionHandlerClient struct {
ourParams *TransportParameters
paramsChan chan TransportParameters
paramsChan chan<- TransportParameters
origConnID protocol.ConnectionID
initialVersion protocol.VersionNumber
supportedVersions []protocol.VersionNumber
version protocol.VersionNumber
@ -22,17 +22,17 @@ type extensionHandlerClient struct {
logger utils.Logger
}
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
var _ TLSExtensionHandler = &extensionHandlerClient{}
var _ tlsExtensionHandler = &extensionHandlerClient{}
// NewExtensionHandlerClient creates a new extension handler for the client.
func NewExtensionHandlerClient(
// newExtensionHandlerClient creates a new extension handler for the client.
func newExtensionHandlerClient(
params *TransportParameters,
origConnID protocol.ConnectionID,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
logger utils.Logger,
) TLSExtensionHandler {
) (tlsExtensionHandler, <-chan TransportParameters) {
// The client reads the transport parameters from the Encrypted Extensions message.
// The paramsChan is used in the session's run loop's select statement.
// We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately.
@ -40,48 +40,48 @@ func NewExtensionHandlerClient(
return &extensionHandlerClient{
ourParams: params,
paramsChan: paramsChan,
origConnID: origConnID,
initialVersion: initialVersion,
supportedVersions: supportedVersions,
version: version,
logger: logger,
}
}, paramsChan
}
func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
if hType != mint.HandshakeTypeClientHello {
func (h *extensionHandlerClient) GetExtensions(msgType uint8) []qtls.Extension {
if messageType(msgType) != typeClientHello {
return nil
}
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
chtp := &clientHelloTransportParameters{
InitialVersion: h.initialVersion,
Parameters: *h.ourParams,
}
return el.Add(&tlsExtensionBody{data: chtp.Marshal()})
return []qtls.Extension{{
Type: quicTLSExtensionType,
Data: (&clientHelloTransportParameters{
InitialVersion: h.initialVersion,
Parameters: *h.ourParams,
}).Marshal(),
}}
}
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{}
found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeEncryptedExtensions {
if found {
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
}
func (h *extensionHandlerClient) ReceivedExtensions(msgType uint8, exts []qtls.Extension) error {
if messageType(msgType) != typeEncryptedExtensions {
return nil
}
// hType == mint.HandshakeTypeEncryptedExtensions
var found bool
eetp := &encryptedExtensionsTransportParameters{}
for _, ext := range exts {
if ext.Type != quicTLSExtensionType {
continue
}
if err := eetp.Unmarshal(ext.Data); err != nil {
return err
}
found = true
}
if !found {
return errors.New("EncryptedExtensions message didn't contain a QUIC extension")
}
eetp := &encryptedExtensionsTransportParameters{}
if err := eetp.Unmarshal(ext.data); err != nil {
return err
}
// check that the negotiated_version is the current version
if eetp.NegotiatedVersion != h.version {
return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
@ -98,15 +98,16 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
}
}
params := eetp.Parameters
// check that the server sent a stateless reset token
if len(eetp.Parameters.StatelessResetToken) == 0 {
if len(params.StatelessResetToken) == 0 {
return errors.New("server didn't sent stateless_reset_token")
}
h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters)
h.paramsChan <- eetp.Parameters
// check the Retry token
if !h.origConnID.Equal(params.OriginalConnectionID) {
return fmt.Errorf("expected original_connection_id to equal %s, is %s", h.origConnID, params.OriginalConnectionID)
}
h.logger.Debugf("Received Transport Parameters: %s", &params)
h.paramsChan <- params
return nil
}
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View File

@ -2,18 +2,16 @@ package handshake
import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
)
type extensionHandlerServer struct {
ourParams *TransportParameters
paramsChan chan TransportParameters
paramsChan chan<- TransportParameters
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
@ -21,62 +19,60 @@ type extensionHandlerServer struct {
logger utils.Logger
}
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
var _ TLSExtensionHandler = &extensionHandlerServer{}
var _ tlsExtensionHandler = &extensionHandlerServer{}
// NewExtensionHandlerServer creates a new extension handler for the server
func NewExtensionHandlerServer(
// newExtensionHandlerServer creates a new extension handler for the server
func newExtensionHandlerServer(
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
logger utils.Logger,
) TLSExtensionHandler {
) (tlsExtensionHandler, <-chan TransportParameters) {
// Processing the ClientHello is performed statelessly (and from a single go-routine).
// Therefore, we have to use a buffered chan to pass the transport parameters to that go routine.
paramsChan := make(chan TransportParameters, 1)
paramsChan := make(chan TransportParameters)
return &extensionHandlerServer{
ourParams: params,
paramsChan: paramsChan,
supportedVersions: supportedVersions,
version: version,
logger: logger,
}
}, paramsChan
}
func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
if hType != mint.HandshakeTypeEncryptedExtensions {
func (h *extensionHandlerServer) GetExtensions(msgType uint8) []qtls.Extension {
if messageType(msgType) != typeEncryptedExtensions {
return nil
}
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
eetp := &encryptedExtensionsTransportParameters{
NegotiatedVersion: h.version,
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
Parameters: *h.ourParams,
}
return el.Add(&tlsExtensionBody{data: eetp.Marshal()})
return []qtls.Extension{{
Type: quicTLSExtensionType,
Data: (&encryptedExtensionsTransportParameters{
NegotiatedVersion: h.version,
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
Parameters: *h.ourParams,
}).Marshal(),
}}
}
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{}
found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeClientHello {
if found {
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
}
func (h *extensionHandlerServer) ReceivedExtensions(msgType uint8, exts []qtls.Extension) error {
if messageType(msgType) != typeClientHello {
return nil
}
var found bool
chtp := &clientHelloTransportParameters{}
for _, ext := range exts {
if ext.Type != quicTLSExtensionType {
continue
}
if err := chtp.Unmarshal(ext.Data); err != nil {
return err
}
found = true
}
if !found {
return errors.New("ClientHello didn't contain a QUIC extension")
}
chtp := &clientHelloTransportParameters{}
if err := chtp.Unmarshal(ext.data); err != nil {
return err
}
// perform the stateless version negotiation validation:
// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
@ -84,17 +80,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
}
// check that the client didn't send a stateless reset token
if len(chtp.Parameters.StatelessResetToken) != 0 {
// TODO: return the correct error type
return errors.New("client sent a stateless reset token")
}
h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
h.paramsChan <- chtp.Parameters
return nil
}
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View File

@ -2,194 +2,190 @@ package handshake
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sort"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// errMalformedTag is returned when the tag value cannot be read
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
type transportParameterID uint16
const (
originalConnectionIDParameterID transportParameterID = 0x0
idleTimeoutParameterID transportParameterID = 0x1
statelessResetTokenParameterID transportParameterID = 0x2
maxPacketSizeParameterID transportParameterID = 0x3
initialMaxDataParameterID transportParameterID = 0x4
initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5
initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6
initialMaxStreamDataUniParameterID transportParameterID = 0x7
initialMaxStreamsBidiParameterID transportParameterID = 0x8
initialMaxStreamsUniParameterID transportParameterID = 0x9
disableMigrationParameterID transportParameterID = 0xc
)
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
StreamFlowControlWindow protocol.ByteCount
ConnectionFlowControlWindow protocol.ByteCount
InitialMaxStreamDataBidiLocal protocol.ByteCount
InitialMaxStreamDataBidiRemote protocol.ByteCount
InitialMaxStreamDataUni protocol.ByteCount
InitialMaxData protocol.ByteCount
MaxPacketSize protocol.ByteCount
MaxUniStreams uint16 // only used for IETF QUIC
MaxBidiStreams uint16 // only used for IETF QUIC
MaxStreams uint32 // only used for gQUIC
MaxUniStreams uint64
MaxBidiStreams uint64
OmitConnectionID bool // only used for gQUIC
IdleTimeout time.Duration
DisableMigration bool // only used for IETF QUIC
StatelessResetToken []byte // only used for IETF QUIC
IdleTimeout time.Duration
DisableMigration bool
StatelessResetToken []byte
OriginalConnectionID protocol.ConnectionID
}
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
params := &TransportParameters{}
if value, ok := tags[TagTCID]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID
r := bytes.NewReader(data)
for r.Len() >= 4 {
paramIDInt, _ := utils.BigEndian.ReadUint16(r)
paramID := transportParameterID(paramIDInt)
paramLen, _ := utils.BigEndian.ReadUint16(r)
parameterIDs = append(parameterIDs, paramID)
switch paramID {
case initialMaxStreamDataBidiLocalParameterID,
initialMaxStreamDataBidiRemoteParameterID,
initialMaxStreamDataUniParameterID,
initialMaxDataParameterID,
initialMaxStreamsBidiParameterID,
initialMaxStreamsUniParameterID,
idleTimeoutParameterID,
maxPacketSizeParameterID:
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
return err
}
default:
if r.Len() < int(paramLen) {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
}
switch paramID {
case disableMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
}
p.DisableMigration = true
case statelessResetTokenParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a stateless_reset_token")
}
if paramLen != 16 {
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
b := make([]byte, 16)
r.Read(b)
p.StatelessResetToken = b
case originalConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent an original_connection_id")
}
p.OriginalConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
default:
r.Seek(int64(paramLen), io.SeekCurrent)
}
}
params.OmitConnectionID = (v == 0)
}
if value, ok := tags[TagMIDS]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
// check that every transport parameter was sent at most once
sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] })
for i := 0; i < len(parameterIDs)-1; i++ {
if parameterIDs[i] == parameterIDs[i+1] {
return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i])
}
params.MaxStreams = v
}
if value, ok := tags[TagICSL]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
if r.Len() != 0 {
return fmt.Errorf("should have read all data. Still have %d bytes", r.Len())
}
if value, ok := tags[TagSFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.StreamFlowControlWindow = protocol.ByteCount(v)
}
if value, ok := tags[TagCFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
}
return params, nil
return nil
}
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
sfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
cfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
mids := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
icsl := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
tags := map[Tag][]byte{
TagICSL: icsl.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(),
func (p *TransportParameters) readNumericTransportParameter(
r *bytes.Reader,
paramID transportParameterID,
expectedLen int,
) error {
remainingLen := r.Len()
val, err := utils.ReadVarInt(r)
if err != nil {
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
}
if p.OmitConnectionID {
tags[TagTCID] = []byte{0, 0, 0, 0}
if remainingLen-r.Len() != expectedLen {
return fmt.Errorf("inconsistent transport parameter length for %d", paramID)
}
return tags
}
func (p *TransportParameters) unmarshal(data []byte) error {
var foundIdleTimeout bool
for len(data) >= 4 {
paramID := binary.BigEndian.Uint16(data[:2])
paramLen := int(binary.BigEndian.Uint16(data[2:4]))
data = data[4:]
if len(data) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen)
switch paramID {
case initialMaxStreamDataBidiLocalParameterID:
p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val)
case initialMaxStreamDataBidiRemoteParameterID:
p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val)
case initialMaxStreamDataUniParameterID:
p.InitialMaxStreamDataUni = protocol.ByteCount(val)
case initialMaxDataParameterID:
p.InitialMaxData = protocol.ByteCount(val)
case initialMaxStreamsBidiParameterID:
p.MaxBidiStreams = val
case initialMaxStreamsUniParameterID:
p.MaxUniStreams = val
case idleTimeoutParameterID:
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Second)
case maxPacketSizeParameterID:
if val < 1200 {
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
}
switch transportParameterID(paramID) {
case initialMaxStreamDataParameterID:
if paramLen != 4 {
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen)
}
p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
case initialMaxDataParameterID:
if paramLen != 4 {
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen)
}
p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
case initialMaxBidiStreamsParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen)
}
p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2])
case initialMaxUniStreamsParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen)
}
p.MaxUniStreams = binary.BigEndian.Uint16(data[:2])
case idleTimeoutParameterID:
foundIdleTimeout = true
if paramLen != 2 {
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen)
}
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second)
case maxPacketSizeParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen)
}
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2]))
if maxPacketSize < 1200 {
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
}
p.MaxPacketSize = maxPacketSize
case disableMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
}
p.DisableMigration = true
case statelessResetTokenParameterID:
if paramLen != 16 {
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
p.StatelessResetToken = data[:16]
}
data = data[paramLen:]
}
if len(data) != 0 {
return fmt.Errorf("should have read all data. Still have %d bytes", len(data))
}
if !foundIdleTimeout {
return errors.New("missing parameter")
p.MaxPacketSize = protocol.ByteCount(val)
default:
return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID)
}
return nil
}
func (p *TransportParameters) marshal(b *bytes.Buffer) {
// initial_max_stream_data
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID))
utils.BigEndian.WriteUint16(b, 4)
utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow))
// initial_max_stream_data_bidi_local
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiLocalParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiLocal))))
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiLocal))
// initial_max_stream_data_bidi_remote
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiRemoteParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiRemote))))
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiRemote))
// initial_max_stream_data_uni
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataUniParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataUni))))
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataUni))
// initial_max_data
utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
utils.BigEndian.WriteUint16(b, 4)
utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxData))))
utils.WriteVarInt(b, uint64(p.InitialMaxData))
// initial_max_bidi_streams
utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, p.MaxBidiStreams)
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsBidiParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxBidiStreams)))
utils.WriteVarInt(b, p.MaxBidiStreams)
// initial_max_uni_streams
utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, p.MaxUniStreams)
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsUniParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxUniStreams)))
utils.WriteVarInt(b, p.MaxUniStreams)
// idle_timeout
utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.IdleTimeout/time.Second))))
utils.WriteVarInt(b, uint64(p.IdleTimeout/time.Second))
// max_packet_size
utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(protocol.MaxReceivePacketSize))))
utils.WriteVarInt(b, uint64(protocol.MaxReceivePacketSize))
// disable_migration
if p.DisableMigration {
utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
@ -200,10 +196,15 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) {
utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
b.Write(p.StatelessResetToken)
}
// original_connection_id
if p.OriginalConnectionID.Len() > 0 {
utils.BigEndian.WriteUint16(b, uint16(originalConnectionIDParameterID))
utils.BigEndian.WriteUint16(b, uint16(p.OriginalConnectionID.Len()))
b.Write(p.OriginalConnectionID.Bytes())
}
}
// String returns a string representation, intended for logging.
// It should only used for IETF QUIC.
func (p *TransportParameters) String() string {
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
return fmt.Sprintf("&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
}

View File

@ -86,30 +86,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked))
}
// GetPacketNumberLen mocks base method
func (m *MockSentPacketHandler) GetPacketNumberLen(arg0 protocol.PacketNumber) protocol.PacketNumberLen {
ret := m.ctrl.Call(m, "GetPacketNumberLen", arg0)
ret0, _ := ret[0].(protocol.PacketNumberLen)
return ret0
}
// GetPacketNumberLen indicates an expected call of GetPacketNumberLen
func (mr *MockSentPacketHandlerMockRecorder) GetPacketNumberLen(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPacketNumberLen", reflect.TypeOf((*MockSentPacketHandler)(nil).GetPacketNumberLen), arg0)
}
// GetStopWaitingFrame mocks base method
func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame {
ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0)
ret0, _ := ret[0].(*wire.StopWaitingFrame)
return ret0
}
// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame
func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockSentPacketHandler)(nil).GetStopWaitingFrame), arg0)
}
// OnAlarm mocks base method
func (m *MockSentPacketHandler) OnAlarm() error {
ret := m.ctrl.Call(m, "OnAlarm")
@ -122,6 +98,31 @@ func (mr *MockSentPacketHandlerMockRecorder) OnAlarm() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAlarm", reflect.TypeOf((*MockSentPacketHandler)(nil).OnAlarm))
}
// PeekPacketNumber mocks base method
func (m *MockSentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
ret := m.ctrl.Call(m, "PeekPacketNumber")
ret0, _ := ret[0].(protocol.PacketNumber)
ret1, _ := ret[1].(protocol.PacketNumberLen)
return ret0, ret1
}
// PeekPacketNumber indicates an expected call of PeekPacketNumber
func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber))
}
// PopPacketNumber mocks base method
func (m *MockSentPacketHandler) PopPacketNumber() protocol.PacketNumber {
ret := m.ctrl.Call(m, "PopPacketNumber")
ret0, _ := ret[0].(protocol.PacketNumber)
return ret0
}
// PopPacketNumber indicates an expected call of PopPacketNumber
func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber))
}
// ReceivedAck mocks base method
func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.PacketNumber, arg2 protocol.EncryptionLevel, arg3 time.Time) error {
ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2, arg3)

View File

@ -0,0 +1,149 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: CryptoSetup)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockCryptoSetup is a mock of CryptoSetup interface
type MockCryptoSetup struct {
ctrl *gomock.Controller
recorder *MockCryptoSetupMockRecorder
}
// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup
type MockCryptoSetupMockRecorder struct {
mock *MockCryptoSetup
}
// NewMockCryptoSetup creates a new mock instance
func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup {
mock := &MockCryptoSetup{ctrl: ctrl}
mock.recorder = &MockCryptoSetupMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder {
return m.recorder
}
// Close mocks base method
func (m *MockCryptoSetup) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close))
}
// ConnectionState mocks base method
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState {
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(handshake.ConnectionState)
return ret0
}
// ConnectionState indicates an expected call of ConnectionState
func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
}
// GetSealer mocks base method
func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
ret := m.ctrl.Call(m, "GetSealer")
ret0, _ := ret[0].(protocol.EncryptionLevel)
ret1, _ := ret[1].(handshake.Sealer)
return ret0, ret1
}
// GetSealer indicates an expected call of GetSealer
func (mr *MockCryptoSetupMockRecorder) GetSealer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealer))
}
// GetSealerWithEncryptionLevel mocks base method
func (m *MockCryptoSetup) GetSealerWithEncryptionLevel(arg0 protocol.EncryptionLevel) (handshake.Sealer, error) {
ret := m.ctrl.Call(m, "GetSealerWithEncryptionLevel", arg0)
ret0, _ := ret[0].(handshake.Sealer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSealerWithEncryptionLevel indicates an expected call of GetSealerWithEncryptionLevel
func (mr *MockCryptoSetupMockRecorder) GetSealerWithEncryptionLevel(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerWithEncryptionLevel", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealerWithEncryptionLevel), arg0)
}
// HandleMessage mocks base method
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// HandleMessage indicates an expected call of HandleMessage
func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
}
// Open1RTT mocks base method
func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open1RTT indicates an expected call of Open1RTT
func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3)
}
// OpenHandshake mocks base method
func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenHandshake indicates an expected call of OpenHandshake
func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
}
// OpenInitial mocks base method
func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenInitial indicates an expected call of OpenInitial
func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3)
}
// RunHandshake mocks base method
func (m *MockCryptoSetup) RunHandshake() error {
ret := m.ctrl.Call(m, "RunHandshake")
ret0, _ := ret[0].(error)
return ret0
}
// RunHandshake indicates an expected call of RunHandshake
func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
}

View File

@ -1,6 +1,7 @@
package mocks
//go:generate sh -c "../mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"

View File

@ -0,0 +1,59 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Sealer)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockSealer is a mock of Sealer interface
type MockSealer struct {
ctrl *gomock.Controller
recorder *MockSealerMockRecorder
}
// MockSealerMockRecorder is the mock recorder for MockSealer
type MockSealerMockRecorder struct {
mock *MockSealer
}
// NewMockSealer creates a new mock instance
func NewMockSealer(ctrl *gomock.Controller) *MockSealer {
mock := &MockSealer{ctrl: ctrl}
mock.recorder = &MockSealerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSealer) EXPECT() *MockSealerMockRecorder {
return m.recorder
}
// Overhead mocks base method
func (m *MockSealer) Overhead() int {
ret := m.ctrl.Call(m, "Overhead")
ret0, _ := ret[0].(int)
return ret0
}
// Overhead indicates an expected call of Overhead
func (mr *MockSealerMockRecorder) Overhead() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockSealer)(nil).Overhead))
}
// Seal mocks base method
func (m *MockSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte {
ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
return ret0
}
// Seal indicates an expected call of Seal
func (mr *MockSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockSealer)(nil).Seal), arg0, arg1, arg2, arg3)
}

View File

@ -1,72 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
mint "github.com/bifurcation/mint"
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
)
// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface
type MockTLSExtensionHandler struct {
ctrl *gomock.Controller
recorder *MockTLSExtensionHandlerMockRecorder
}
// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler
type MockTLSExtensionHandlerMockRecorder struct {
mock *MockTLSExtensionHandler
}
// NewMockTLSExtensionHandler creates a new mock instance
func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler {
mock := &MockTLSExtensionHandler{ctrl: ctrl}
mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder {
return m.recorder
}
// GetPeerParams mocks base method
func (m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters {
ret := m.ctrl.Call(m, "GetPeerParams")
ret0, _ := ret[0].(<-chan handshake.TransportParameters)
return ret0
}
// GetPeerParams indicates an expected call of GetPeerParams
func (mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams))
}
// Receive mocks base method
func (m *MockTLSExtensionHandler) Receive(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error {
ret := m.ctrl.Call(m, "Receive", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Receive indicates an expected call of Receive
func (mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1)
}
// Send mocks base method
func (m *MockTLSExtensionHandler) Send(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error {
ret := m.ctrl.Call(m, "Send", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Send indicates an expected call of Send
func (mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1)
}

View File

@ -7,22 +7,22 @@ type EncryptionLevel int
const (
// EncryptionUnspecified is a not specified encryption level
EncryptionUnspecified EncryptionLevel = iota
// EncryptionUnencrypted is not encrypted
EncryptionUnencrypted
// EncryptionSecure is encrypted, but not forward secure
EncryptionSecure
// EncryptionForwardSecure is forward secure
EncryptionForwardSecure
// EncryptionInitial is the Initial encryption level
EncryptionInitial
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake
// Encryption1RTT is the 1-RTT encryption level
Encryption1RTT
)
func (e EncryptionLevel) String() string {
switch e {
case EncryptionUnencrypted:
return "unencrypted"
case EncryptionSecure:
return "encrypted (not forward-secure)"
case EncryptionForwardSecure:
return "forward-secure"
case EncryptionInitial:
return "Initial"
case EncryptionHandshake:
return "Handshake"
case Encryption1RTT:
return "1-RTT"
}
return "unknown"
}

View File

@ -8,17 +8,13 @@ func InferPacketNumber(
version VersionNumber,
) PacketNumber {
var epochDelta PacketNumber
if version.UsesVarintPacketNumbers() {
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30
}
} else {
epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8)
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30
}
epoch := lastPacketNumber & ^(epochDelta - 1)
prevEpochBegin := epoch - epochDelta
@ -48,8 +44,7 @@ func delta(a, b PacketNumber) PacketNumber {
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked)
if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) ||
!version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) {
if diff < (1 << (14 - 1)) {
return PacketNumberLen2
}
return PacketNumberLen4
@ -63,8 +58,5 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
return PacketNumberLen2
}
if packetNumber < (1 << (uint8(PacketNumberLen4) * 8)) {
return PacketNumberLen4
}
return PacketNumberLen6
return PacketNumberLen4
}

View File

@ -8,6 +8,9 @@ const MaxPacketSizeIPv4 = 1252
// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
const MaxPacketSizeIPv6 = 1232
// MinStatelessResetSize is the minimum size of a stateless reset packet
const MinStatelessResetSize = 1 + 20 + 16
// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
const NonForwardSecurePacketSizeReduction = 50
@ -24,38 +27,22 @@ const InitialCongestionWindow ByteCount = 32 * DefaultTCPMSS
// session queues for later until it sends a public reset.
const MaxUndecryptablePackets = 10
// PublicResetTimeout is the time to wait before sending a Public Reset when receiving too many undecryptable packets during the handshake
// This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto
const PublicResetTimeout = 500 * time.Millisecond
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
// This is the value that Google servers are using
const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB
// ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data
// This is the value that Google servers are using
const ReceiveConnectionFlowControlWindow = (1 << 10) * 48 // 48 kB
// DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server
// This is the value that Google servers are using
const DefaultMaxReceiveStreamFlowControlWindowServer = 1 * (1 << 20) // 1 MB
// DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server
// This is the value that Google servers are using
const DefaultMaxReceiveConnectionFlowControlWindowServer = 1.5 * (1 << 20) // 1.5 MB
// DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client
// This is the value that Chromium is using
const DefaultMaxReceiveStreamFlowControlWindowClient = 6 * (1 << 20) // 6 MB
// DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client
// This is the value that Google servers are using
const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15 MB
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
// This is the value that Chromium is using
const ConnectionFlowControlMultiplier = 1.5
// InitialMaxStreamData is the stream-level flow control window for receiving data
const InitialMaxStreamData = (1 << 10) * 512 // 512 kb
// InitialMaxData is the connection-level flow control window for receiving data
const InitialMaxData = ConnectionFlowControlMultiplier * InitialMaxStreamData
// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data, for the server
const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB
// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data, for the server
const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 12 MB
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
const WindowUpdateThreshold = 0.25
@ -65,12 +52,6 @@ const DefaultMaxIncomingStreams = 100
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
const DefaultMaxIncomingUniStreams = 100
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
const MaxStreamsMultiplier = 1.1
// MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used.
const MaxStreamsMinimumIncrement = 10
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets
@ -103,15 +84,9 @@ const MaxNonRetransmittableAcks = 19
// prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000
// CryptoMaxParams is the upper limit for the number of parameters in a crypto message.
// Value taken from Chrome.
const CryptoMaxParams = 128
// CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message.
const CryptoParameterMaxLength = 4000
// EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX.
const EphermalKeyLifetime = time.Minute
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
// This limits the size of the ClientHello and Certificates that can be received.
const MaxCryptoStreamOffset = 16 * (1 << 10)
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
const MinRemoteIdleTimeout = 5 * time.Second
@ -122,12 +97,9 @@ const DefaultIdleTimeout = 30 * time.Second
// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
const DefaultHandshakeTimeout = 10 * time.Second
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
// RetiredConnectionIDDeleteTimeout is the time we keep closed sessions around in order to retransmit the CONNECTION_CLOSE.
// after this time all information about the old connection will be deleted
const ClosedSessionDeleteTimeout = time.Minute
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
const NumCachedCertificates = 128
const RetiredConnectionIDDeleteTimeout = 5 * time.Second
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
@ -135,7 +107,7 @@ const NumCachedCertificates = 128
// 2. it reduces the head-of-line blocking, when a packet is lost
const MinStreamFrameSize ByteCount = 128
// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write
// MaxAckFrameSize is the maximum size for an ACK frame that we write
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
// The MaxAckFrameSize should be large enough to encode many ACK range,
// but must ensure that a maximum size ACK frame fits into one packet.
@ -149,6 +121,3 @@ const MinPacingDelay time.Duration = 100 * time.Microsecond
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
// if no other value is configured.
const DefaultConnectionIDLength = 4
// MaxRetries is the maximum number of Retries a client will do before failing the connection.
const MaxRetries = 3

View File

@ -19,11 +19,9 @@ const (
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
// PacketNumberLen6 is a packet number length of 6 bytes
PacketNumberLen6 PacketNumberLen = 6
)
// The PacketType is the Long Header Type (only used for the IETF draft header format)
// The PacketType is the Long Header Type
type PacketType uint8
const (
@ -71,10 +69,7 @@ const MaxReceivePacketSize ByteCount = 1452 - 64
// Used in QUIC for congestion window computations in bytes.
const DefaultTCPMSS ByteCount = 1460
// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC)
const MinClientHelloSize = 1024
// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is required to have.
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200
// MaxClientHellos is the maximum number of times we'll send a client hello
@ -83,8 +78,5 @@ const MinInitialPacketSize = 1200
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets.
const ConnectionIDLenGQUIC = 8
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8

View File

@ -3,34 +3,65 @@ package protocol
// A StreamID in QUIC
type StreamID uint64
// MaxBidiStreamID is the highest stream ID that the peer is allowed to open,
// when it is allowed to open numStreams bidirectional streams.
// It is only valid for IETF QUIC.
func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
// StreamType encodes if this is a unidirectional or bidirectional stream
type StreamType uint8
const (
// StreamTypeUni is a unidirectional stream
StreamTypeUni StreamType = iota
// StreamTypeBidi is a bidirectional stream
StreamTypeBidi
)
// InitiatedBy says if the stream was initiated by the client or by the server
func (s StreamID) InitiatedBy() Perspective {
if s%2 == 0 {
return PerspectiveClient
}
return PerspectiveServer
}
//Type says if this is a unidirectional or bidirectional stream
func (s StreamID) Type() StreamType {
if s%4 >= 2 {
return StreamTypeUni
}
return StreamTypeBidi
}
// StreamNum returns how many streams in total are below this
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
func (s StreamID) StreamNum() uint64 {
return uint64(s/4) + 1
}
// MaxStreamID is the highest stream ID that a peer is allowed to open,
// when it is allowed to open numStreams.
func MaxStreamID(stype StreamType, numStreams uint64, pers Perspective) StreamID {
if numStreams == 0 {
return 0
}
var first StreamID
if pers == PerspectiveClient {
first = 1
} else {
first = 4
switch stype {
case StreamTypeBidi:
switch pers {
case PerspectiveClient:
first = 0
case PerspectiveServer:
first = 1
}
case StreamTypeUni:
switch pers {
case PerspectiveClient:
first = 2
case PerspectiveServer:
first = 3
}
}
return first + 4*StreamID(numStreams-1)
}
// MaxUniStreamID is the highest stream ID that the peer is allowed to open,
// when it is allowed to open numStreams unidirectional streams.
// It is only valid for IETF QUIC.
func MaxUniStreamID(numStreams int, pers Perspective) StreamID {
if numStreams == 0 {
return 0
}
var first StreamID
if pers == PerspectiveClient {
first = 3
} else {
first = 2
}
return first + 4*StreamID(numStreams-1)
// FirstStream returns the first valid stream ID
func FirstStream(stype StreamType, pers Perspective) StreamID {
return MaxStreamID(stype, 1, pers)
}

View File

@ -18,32 +18,18 @@ const (
// The version numbers, making grepping easier
const (
Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9
Version43 VersionNumber = gquicVersion0 + 4*0x100 + 0x3
Version44 VersionNumber = gquicVersion0 + 4*0x100 + 0x4
VersionTLS VersionNumber = 101
VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionWhatever VersionNumber = 1 // for when the version doesn't matter
VersionUnknown VersionNumber = math.MaxUint32
VersionMilestone0_10_0 VersionNumber = 0x51474f02
)
// SupportedVersions lists the versions that the server supports
// must be in sorted descending order
var SupportedVersions = []VersionNumber{
Version44,
Version43,
Version39,
}
var SupportedVersions = []VersionNumber{VersionTLS}
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool {
return v == VersionTLS || v == VersionMilestone0_10_0 || IsSupportedVersion(SupportedVersions, v)
}
// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake
func (vn VersionNumber) UsesTLS() bool {
return !vn.isGQUIC()
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
}
func (vn VersionNumber) String() string {
@ -52,8 +38,6 @@ func (vn VersionNumber) String() string {
return "whatever"
case VersionUnknown:
return "unknown"
case VersionMilestone0_10_0:
return "quic-go Milestone 0.10.0"
case VersionTLS:
return "TLS dev version (WIP)"
default:
@ -66,61 +50,9 @@ func (vn VersionNumber) String() string {
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
func (vn VersionNumber) ToAltSvc() string {
if vn.isGQUIC() {
return fmt.Sprintf("%d", vn.toGQUICVersion())
}
return fmt.Sprintf("%d", vn)
}
// CryptoStreamID gets the Stream ID of the crypto stream
func (vn VersionNumber) CryptoStreamID() StreamID {
if vn.isGQUIC() {
return 1
}
return 0
}
// UsesIETFFrameFormat tells if this version uses the IETF frame format
func (vn VersionNumber) UsesIETFFrameFormat() bool {
return !vn.isGQUIC()
}
// UsesIETFHeaderFormat tells if this version uses the IETF header format
func (vn VersionNumber) UsesIETFHeaderFormat() bool {
return !vn.isGQUIC() || vn >= Version44
}
// UsesLengthInHeader tells if this version uses the Length field in the IETF header
func (vn VersionNumber) UsesLengthInHeader() bool {
return !vn.isGQUIC()
}
// UsesTokenInHeader tells if this version uses the Token field in the IETF header
func (vn VersionNumber) UsesTokenInHeader() bool {
return !vn.isGQUIC()
}
// UsesStopWaitingFrames tells if this version uses STOP_WAITING frames
func (vn VersionNumber) UsesStopWaitingFrames() bool {
return vn.isGQUIC() && vn <= Version43
}
// UsesVarintPacketNumbers tells if this version uses 7/14/30 bit packet numbers
func (vn VersionNumber) UsesVarintPacketNumbers() bool {
return !vn.isGQUIC()
}
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
if id == vn.CryptoStreamID() {
return false
}
if vn.isGQUIC() && id == 3 {
return false
}
return true
}
func (vn VersionNumber) isGQUIC() bool {
return vn > gquicVersion0 && vn <= maxGquicVersion
}

View File

@ -5,7 +5,7 @@ import (
)
// ErrorCode can be used as a normal error without reason.
type ErrorCode uint32
type ErrorCode uint16
func (e ErrorCode) Error() string {
return e.String()

View File

@ -13,13 +13,6 @@ type ByteOrder interface {
ReadUint16(io.ByteReader) (uint16, error)
WriteUint64(*bytes.Buffer, uint64)
WriteUint56(*bytes.Buffer, uint64)
WriteUint48(*bytes.Buffer, uint64)
WriteUint40(*bytes.Buffer, uint64)
WriteUint32(*bytes.Buffer, uint32)
WriteUint24(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16)
ReadUfloat16(io.ByteReader) (uint64, error)
WriteUfloat16(*bytes.Buffer, uint64)
}

View File

@ -2,7 +2,6 @@ package utils
import (
"bytes"
"fmt"
"io"
)
@ -97,61 +96,12 @@ func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
})
}
// WriteUint56 writes 56 bit of a uint64
func (bigEndian) WriteUint56(b *bytes.Buffer, i uint64) {
if i >= (1 << 56) {
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
}
b.Write([]byte{
uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint48 writes 48 bit of a uint64
func (bigEndian) WriteUint48(b *bytes.Buffer, i uint64) {
if i >= (1 << 48) {
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
}
b.Write([]byte{
uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint40 writes 40 bit of a uint64
func (bigEndian) WriteUint40(b *bytes.Buffer, i uint64) {
if i >= (1 << 40) {
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
}
b.Write([]byte{
uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint32 writes a uint32
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint24 writes 24 bit of a uint32
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
if i >= (1 << 24) {
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
}
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint16 writes a uint16
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i >> 8), uint8(i)})
}
func (l bigEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
return readUfloat16(b, l)
}
func (l bigEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
writeUfloat16(b, l, val)
}

View File

@ -1,157 +0,0 @@
package utils
import (
"bytes"
"fmt"
"io"
)
// LittleEndian is the little-endian implementation of ByteOrder.
var LittleEndian ByteOrder = littleEndian{}
type littleEndian struct{}
var _ ByteOrder = &littleEndian{}
// ReadUintN reads N bytes
func (littleEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
var res uint64
for i := uint8(0); i < length; i++ {
bt, err := b.ReadByte()
if err != nil {
return 0, err
}
res ^= uint64(bt) << (i * 8)
}
return res, nil
}
// ReadUint64 reads a uint64
func (littleEndian) ReadUint64(b io.ByteReader) (uint64, error) {
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b5, err = b.ReadByte(); err != nil {
return 0, err
}
if b6, err = b.ReadByte(); err != nil {
return 0, err
}
if b7, err = b.ReadByte(); err != nil {
return 0, err
}
if b8, err = b.ReadByte(); err != nil {
return 0, err
}
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
}
// ReadUint32 reads a uint32
func (littleEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
}
// ReadUint16 reads a uint16
func (littleEndian) ReadUint16(b io.ByteReader) (uint16, error) {
var b1, b2 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
return uint16(b1) + uint16(b2)<<8, nil
}
// WriteUint64 writes a uint64
func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) {
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), uint8(i >> 56),
})
}
// WriteUint56 writes 56 bit of a uint64
func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) {
if i >= (1 << 56) {
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48),
})
}
// WriteUint48 writes 48 bit of a uint64
func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) {
if i >= (1 << 48) {
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40),
})
}
// WriteUint40 writes 40 bit of a uint64
func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) {
if i >= (1 << 40) {
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16),
uint8(i >> 24), uint8(i >> 32),
})
}
// WriteUint32 writes a uint32
func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24)})
}
// WriteUint24 writes 24 bit of a uint32
func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) {
if i >= (1 << 24) {
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
}
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)})
}
// WriteUint16 writes a uint16
func (littleEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i), uint8(i >> 8)})
}
func (l littleEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
return readUfloat16(b, l)
}
func (l littleEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
writeUfloat16(b, l, val)
}

View File

@ -1,86 +0,0 @@
package utils
import (
"bytes"
"io"
"math"
)
// We define an unsigned 16-bit floating point value, inspired by IEEE floats
// (http://en.wikipedia.org/wiki/Half_precision_floating-point_format),
// with 5-bit exponent (bias 1), 11-bit mantissa (effective 12 with hidden
// bit) and denormals, but without signs, transfinites or fractions. Wire format
// 16 bits (little-endian byte order) are split into exponent (high 5) and
// mantissa (low 11) and decoded as:
// uint64_t value;
// if (exponent == 0) value = mantissa;
// else value = (mantissa | 1 << 11) << (exponent - 1)
const uFloat16ExponentBits = 5
const uFloat16MaxExponent = (1 << uFloat16ExponentBits) - 2 // 30
const uFloat16MantissaBits = 16 - uFloat16ExponentBits // 11
const uFloat16MantissaEffectiveBits = uFloat16MantissaBits + 1 // 12
const uFloat16MaxValue = ((uint64(1) << uFloat16MantissaEffectiveBits) - 1) << uFloat16MaxExponent // 0x3FFC0000000
// readUfloat16 reads a float in the QUIC-float16 format and returns its uint64 representation
func readUfloat16(b io.ByteReader, byteOrder ByteOrder) (uint64, error) {
val, err := byteOrder.ReadUint16(b)
if err != nil {
return 0, err
}
res := uint64(val)
if res < (1 << uFloat16MantissaEffectiveBits) {
// Fast path: either the value is denormalized (no hidden bit), or
// normalized (hidden bit set, exponent offset by one) with exponent zero.
// Zero exponent offset by one sets the bit exactly where the hidden bit is.
// So in both cases the value encodes itself.
return res, nil
}
exponent := val >> uFloat16MantissaBits // No sign extend on uint!
// After the fast pass, the exponent is at least one (offset by one).
// Un-offset the exponent.
exponent--
// Here we need to clear the exponent and set the hidden bit. We have already
// decremented the exponent, so when we subtract it, it leaves behind the
// hidden bit.
res -= uint64(exponent) << uFloat16MantissaBits
res <<= exponent
return res, nil
}
// writeUfloat16 writes a float in the QUIC-float16 format from its uint64 representation
func writeUfloat16(b *bytes.Buffer, byteOrder ByteOrder, value uint64) {
var result uint16
if value < (uint64(1) << uFloat16MantissaEffectiveBits) {
// Fast path: either the value is denormalized, or has exponent zero.
// Both cases are represented by the value itself.
result = uint16(value)
} else if value >= uFloat16MaxValue {
// Value is out of range; clamp it to the maximum representable.
result = math.MaxUint16
} else {
// The highest bit is between position 13 and 42 (zero-based), which
// corresponds to exponent 1-30. In the output, mantissa is from 0 to 10,
// hidden bit is 11 and exponent is 11 to 15. Shift the highest bit to 11
// and count the shifts.
exponent := uint16(0)
for offset := uint16(16); offset > 0; offset /= 2 {
// Right-shift the value until the highest bit is in position 11.
// For offset of 16, 8, 4, 2 and 1 (binary search over 1-30),
// shift if the bit is at or above 11 + offset.
if value >= (uint64(1) << (uFloat16MantissaBits + offset)) {
exponent += offset
value >>= offset
}
}
// Hidden bit (position 11) is set. We should remove it and increment the
// exponent. Equivalently, we just add it to the exponent.
// This hides the bit.
result = (uint16(value) + (exponent << uFloat16MantissaBits))
}
byteOrder.WriteUint16(b, result)
}

View File

@ -13,29 +13,21 @@ import (
// TODO: use the value sent in the transport parameters
const ackDelayExponent = 3
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
// An AckFrame is an ACK frame
type AckFrame struct {
AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
DelayTime time.Duration
}
func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
return parseAckOrAckEcnFrame(r, false, version)
}
func parseAckEcnFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
return parseAckOrAckEcnFrame(r, true, version)
}
// parseAckFrame reads an ACK frame
func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNumber) (*AckFrame, error) {
if !version.UsesIETFFrameFormat() {
return parseAckFrameLegacy(r, version)
}
if _, err := r.ReadByte(); err != nil {
func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
ecn := typeByte&0x1 > 0
frame := &AckFrame{}
@ -50,14 +42,6 @@ func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNu
}
frame.DelayTime = time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
if ecn {
for i := 0; i < 3; i++ {
if _, err := utils.ReadVarInt(r); err != nil {
return nil, err
}
}
}
numBlocks, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
@ -103,16 +87,22 @@ func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNu
if !frame.validateAckRanges() {
return nil, errInvalidAckRanges
}
// parse (and skip) the ECN section
if ecn {
for i := 0; i < 3; i++ {
if _, err := utils.ReadVarInt(r); err != nil {
return nil, err
}
}
}
return frame, nil
}
// Write writes an ACK frame.
func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
if !version.UsesIETFFrameFormat() {
return f.writeLegacy(b, version)
}
b.WriteByte(0x0d)
b.WriteByte(0x2)
utils.WriteVarInt(b, uint64(f.LargestAcked()))
utils.WriteVarInt(b, encodeAckDelay(f.DelayTime))
@ -134,10 +124,6 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
// Length of a written frame
func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() {
return f.lengthLegacy(version)
}
largestAcked := f.AckRanges[0].Largest
numRanges := f.numEncodableAckRanges()

View File

@ -1,364 +0,0 @@
package wire
import (
"bytes"
"errors"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, error) {
frame := &AckFrame{}
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
hasMissingRanges := typeByte&0x20 == 0x20
largestAckedLen := 2 * ((typeByte & 0x0C) >> 2)
if largestAckedLen == 0 {
largestAckedLen = 1
}
missingSequenceNumberDeltaLen := 2 * (typeByte & 0x03)
if missingSequenceNumberDeltaLen == 0 {
missingSequenceNumberDeltaLen = 1
}
la, err := utils.BigEndian.ReadUintN(r, largestAckedLen)
if err != nil {
return nil, err
}
largestAcked := protocol.PacketNumber(la)
delay, err := utils.BigEndian.ReadUfloat16(r)
if err != nil {
return nil, err
}
frame.DelayTime = time.Duration(delay) * time.Microsecond
var numAckBlocks uint8
if hasMissingRanges {
numAckBlocks, err = r.ReadByte()
if err != nil {
return nil, err
}
}
if hasMissingRanges && numAckBlocks == 0 {
return nil, errInvalidAckRanges
}
abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen)
if err != nil {
return nil, err
}
ackBlockLength := protocol.PacketNumber(abl)
if largestAcked > 0 && ackBlockLength < 1 {
return nil, errors.New("invalid first ACK range")
}
if ackBlockLength > largestAcked+1 {
return nil, errInvalidAckRanges
}
if hasMissingRanges {
ackRange := AckRange{
Smallest: largestAcked - ackBlockLength + 1,
Largest: largestAcked,
}
frame.AckRanges = append(frame.AckRanges, ackRange)
var inLongBlock bool
var lastRangeComplete bool
for i := uint8(0); i < numAckBlocks; i++ {
var gap uint8
gap, err = r.ReadByte()
if err != nil {
return nil, err
}
abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen)
if err != nil {
return nil, err
}
ackBlockLength := protocol.PacketNumber(abl)
if inLongBlock {
frame.AckRanges[len(frame.AckRanges)-1].Smallest -= protocol.PacketNumber(gap) + ackBlockLength
frame.AckRanges[len(frame.AckRanges)-1].Largest -= protocol.PacketNumber(gap)
} else {
lastRangeComplete = false
ackRange := AckRange{
Largest: frame.AckRanges[len(frame.AckRanges)-1].Smallest - protocol.PacketNumber(gap) - 1,
}
ackRange.Smallest = ackRange.Largest - ackBlockLength + 1
frame.AckRanges = append(frame.AckRanges, ackRange)
}
if ackBlockLength > 0 {
lastRangeComplete = true
}
inLongBlock = (ackBlockLength == 0)
}
// if the last range was not complete, First and Last make no sense
// remove the range from frame.AckRanges
if !lastRangeComplete {
frame.AckRanges = frame.AckRanges[:len(frame.AckRanges)-1]
}
} else {
frame.AckRanges = make([]AckRange, 1)
if largestAcked != 0 {
frame.AckRanges[0].Largest = largestAcked
frame.AckRanges[0].Smallest = largestAcked + 1 - ackBlockLength
}
}
if !frame.validateAckRanges() {
return nil, errInvalidAckRanges
}
var numTimestamp byte
numTimestamp, err = r.ReadByte()
if err != nil {
return nil, err
}
if numTimestamp > 0 {
// Delta Largest acked
_, err = r.ReadByte()
if err != nil {
return nil, err
}
// First Timestamp
_, err = utils.BigEndian.ReadUint32(r)
if err != nil {
return nil, err
}
for i := 0; i < int(numTimestamp)-1; i++ {
// Delta Largest acked
_, err = r.ReadByte()
if err != nil {
return nil, err
}
// Time Since Previous Timestamp
_, err = utils.BigEndian.ReadUint16(r)
if err != nil {
return nil, err
}
}
}
return frame, nil
}
func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error {
largestAcked := f.LargestAcked()
largestAckedLen := protocol.GetPacketNumberLength(largestAcked)
typeByte := uint8(0x40)
if largestAckedLen != protocol.PacketNumberLen1 {
typeByte ^= (uint8(largestAckedLen / 2)) << 2
}
missingSequenceNumberDeltaLen := f.getMissingSequenceNumberDeltaLen()
if missingSequenceNumberDeltaLen != protocol.PacketNumberLen1 {
typeByte ^= (uint8(missingSequenceNumberDeltaLen / 2))
}
if f.HasMissingRanges() {
typeByte |= 0x20
}
b.WriteByte(typeByte)
switch largestAckedLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(largestAcked))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(largestAcked))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(largestAcked))
case protocol.PacketNumberLen6:
utils.BigEndian.WriteUint48(b, uint64(largestAcked)&(1<<48-1))
}
utils.BigEndian.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond))
var numRanges uint64
var numRangesWritten uint64
if f.HasMissingRanges() {
numRanges = f.numWritableNackRanges()
if numRanges > 0xFF {
panic("AckFrame: Too many ACK ranges")
}
b.WriteByte(uint8(numRanges - 1))
}
var firstAckBlockLength protocol.PacketNumber
if !f.HasMissingRanges() {
firstAckBlockLength = largestAcked - f.LowestAcked() + 1
} else {
firstAckBlockLength = largestAcked - f.AckRanges[0].Smallest + 1
numRangesWritten++
}
switch missingSequenceNumberDeltaLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(firstAckBlockLength))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(firstAckBlockLength))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(firstAckBlockLength))
case protocol.PacketNumberLen6:
utils.BigEndian.WriteUint48(b, uint64(firstAckBlockLength)&(1<<48-1))
}
for i, ackRange := range f.AckRanges {
if i == 0 {
continue
}
length := ackRange.Largest - ackRange.Smallest + 1
gap := f.AckRanges[i-1].Smallest - ackRange.Largest - 1
num := gap/0xFF + 1
if gap%0xFF == 0 {
num--
}
if num == 1 {
b.WriteByte(uint8(gap))
switch missingSequenceNumberDeltaLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(length))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(length))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(length))
case protocol.PacketNumberLen6:
utils.BigEndian.WriteUint48(b, uint64(length)&(1<<48-1))
}
numRangesWritten++
} else {
for i := 0; i < int(num); i++ {
var lengthWritten uint64
var gapWritten uint8
if i == int(num)-1 { // last block
lengthWritten = uint64(length)
gapWritten = uint8(1 + ((gap - 1) % 255))
} else {
lengthWritten = 0
gapWritten = 0xFF
}
b.WriteByte(gapWritten)
switch missingSequenceNumberDeltaLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(lengthWritten))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(lengthWritten))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(lengthWritten))
case protocol.PacketNumberLen6:
utils.BigEndian.WriteUint48(b, lengthWritten&(1<<48-1))
}
numRangesWritten++
}
}
// this is needed if not all AckRanges can be written to the ACK frame (if there are more than 0xFF)
if numRangesWritten >= numRanges {
break
}
}
if numRanges != numRangesWritten {
return errors.New("BUG: Inconsistent number of ACK ranges written")
}
b.WriteByte(0) // no timestamps
return nil
}
func (f *AckFrame) lengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp
length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked()))
missingSequenceNumberDeltaLen := protocol.ByteCount(f.getMissingSequenceNumberDeltaLen())
if f.HasMissingRanges() {
length += (1 + missingSequenceNumberDeltaLen) * protocol.ByteCount(f.numWritableNackRanges())
} else {
length += missingSequenceNumberDeltaLen
}
// we don't write
return length
}
// numWritableNackRanges calculates the number of ACK blocks that are about to be written
// this number is different from len(f.AckRanges) for the case of long gaps (> 255 packets)
func (f *AckFrame) numWritableNackRanges() uint64 {
if len(f.AckRanges) == 0 {
return 0
}
var numRanges uint64
for i, ackRange := range f.AckRanges {
if i == 0 {
continue
}
lastAckRange := f.AckRanges[i-1]
gap := lastAckRange.Smallest - ackRange.Largest - 1
rangeLength := 1 + uint64(gap)/0xFF
if uint64(gap)%0xFF == 0 {
rangeLength--
}
if numRanges+rangeLength < 0xFF {
numRanges += rangeLength
} else {
break
}
}
return numRanges + 1
}
func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen {
var maxRangeLength protocol.PacketNumber
if f.HasMissingRanges() {
for _, ackRange := range f.AckRanges {
rangeLength := ackRange.Largest - ackRange.Smallest + 1
if rangeLength > maxRangeLength {
maxRangeLength = rangeLength
}
}
} else {
maxRangeLength = f.LargestAcked() - f.LowestAcked() + 1
}
if maxRangeLength <= 0xFF {
return protocol.PacketNumberLen1
}
if maxRangeLength <= 0xFFFF {
return protocol.PacketNumberLen2
}
if maxRangeLength <= 0xFFFFFFFF {
return protocol.PacketNumberLen4
}
return protocol.PacketNumberLen6
}

View File

@ -1,45 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A BlockedFrame is a BLOCKED frame
type BlockedFrame struct {
Offset protocol.ByteCount
}
// parseBlockedFrame parses a BLOCKED frame
func parseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &BlockedFrame{
Offset: protocol.ByteCount(offset),
}, nil
}
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
if !version.UsesIETFFrameFormat() {
return (&blockedFrameLegacy{}).Write(b, version)
}
typeByte := uint8(0x08)
b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.Offset))
return nil
}
// Length of a written frame
func (f *BlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
if !version.UsesIETFFrameFormat() {
return 1 + 4
}
return 1 + utils.VarIntLen(uint64(f.Offset))
}

View File

@ -1,37 +0,0 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type blockedFrameLegacy struct {
StreamID protocol.StreamID
}
// parseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format)
// The frame returned is
// * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream
// * a BLOCKED frame, if the BLOCKED applies to the connection
func parseBlockedFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err
}
streamID, err := utils.BigEndian.ReadUint32(r)
if err != nil {
return nil, err
}
if streamID == 0 {
return &BlockedFrame{}, nil
}
return &StreamBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil
}
//Write writes a BLOCKED frame
func (f *blockedFrameLegacy) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x05)
utils.BigEndian.WriteUint32(b, uint32(f.StreamID))
return nil
}

View File

@ -2,52 +2,43 @@ package wire
import (
"bytes"
"errors"
"io"
"math"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// A ConnectionCloseFrame in QUIC
// A ConnectionCloseFrame is a CONNECTION_CLOSE frame
type ConnectionCloseFrame struct {
ErrorCode qerr.ErrorCode
ReasonPhrase string
IsApplicationError bool
ErrorCode qerr.ErrorCode
ReasonPhrase string
}
// parseConnectionCloseFrame reads a CONNECTION_CLOSE frame
func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
var errorCode qerr.ErrorCode
var reasonPhraseLen uint64
if version.UsesIETFFrameFormat() {
ec, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return nil, err
}
errorCode = qerr.ErrorCode(ec)
reasonPhraseLen, err = utils.ReadVarInt(r)
if err != nil {
return nil, err
}
} else {
ec, err := utils.BigEndian.ReadUint32(r)
if err != nil {
return nil, err
}
errorCode = qerr.ErrorCode(ec)
length, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return nil, err
}
reasonPhraseLen = uint64(length)
f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d}
ec, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return nil, err
}
f.ErrorCode = qerr.ErrorCode(ec)
// read the Frame Type, if this is not an application error
if !f.IsApplicationError {
if _, err := utils.ReadVarInt(r); err != nil {
return nil, err
}
}
var reasonPhraseLen uint64
reasonPhraseLen, err = utils.ReadVarInt(r)
if err != nil {
return nil, err
}
// shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet
// reading the whole reason phrase would result in EOF when attempting to READ
@ -60,37 +51,31 @@ func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber)
// this should never happen, since we already checked the reasonPhraseLen earlier
return nil, err
}
return &ConnectionCloseFrame{
ErrorCode: errorCode,
ReasonPhrase: string(reasonPhrase),
}, nil
f.ReasonPhrase = string(reasonPhrase)
return f, nil
}
// Length of a written frame
func (f *ConnectionCloseFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
if version.UsesIETFFrameFormat() {
return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
length := 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
if !f.IsApplicationError {
length++ // for the frame type
}
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase))
return length
}
// Write writes an CONNECTION_CLOSE frame.
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x02)
if len(f.ReasonPhrase) > math.MaxUint16 {
return errors.New("ConnectionFrame: ReasonPhrase too long")
}
if version.UsesIETFFrameFormat() {
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
utils.WriteVarInt(b, uint64(len(f.ReasonPhrase)))
if f.IsApplicationError {
b.WriteByte(0x1d)
} else {
utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase)))
b.WriteByte(0x1c)
}
b.WriteString(f.ReasonPhrase)
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
if !f.IsApplicationError {
utils.WriteVarInt(b, 0)
}
utils.WriteVarInt(b, uint64(len(f.ReasonPhrase)))
b.WriteString(f.ReasonPhrase)
return nil
}

View File

@ -0,0 +1,71 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A CryptoFrame is a CRYPTO frame
type CryptoFrame struct {
Offset protocol.ByteCount
Data []byte
}
func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
frame := &CryptoFrame{}
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
frame.Offset = protocol.ByteCount(offset)
dataLen, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
if dataLen > uint64(r.Len()) {
return nil, io.EOF
}
if dataLen != 0 {
frame.Data = make([]byte, dataLen)
if _, err := io.ReadFull(r, frame.Data); err != nil {
// this should never happen, since we already checked the dataLen earlier
return nil, err
}
}
return frame, nil
}
func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x6)
utils.WriteVarInt(b, uint64(f.Offset))
utils.WriteVarInt(b, uint64(len(f.Data)))
b.Write(f.Data)
return nil
}
// Length of a written frame
func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.Offset)) + utils.VarIntLen(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
}
// MaxDataLen returns the maximum data length
func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount {
// pretend that the data size will be 1 bytes
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
headerLen := 1 + utils.VarIntLen(uint64(f.Offset)) + 1
if headerLen > maxSize {
return 0
}
maxDataLen := maxSize - headerLen
if utils.VarIntLen(uint64(maxDataLen)) != 1 {
maxDataLen--
}
return maxDataLen
}

View File

@ -0,0 +1,38 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// A DataBlockedFrame is a DATA_BLOCKED frame
type DataBlockedFrame struct {
DataLimit protocol.ByteCount
}
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
offset, err := utils.ReadVarInt(r)
if err != nil {
return nil, err
}
return &DataBlockedFrame{
DataLimit: protocol.ByteCount(offset),
}, nil
}
func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
typeByte := uint8(0x14)
b.WriteByte(typeByte)
utils.WriteVarInt(b, uint64(f.DataLimit))
return nil
}
// Length of a written frame
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + utils.VarIntLen(uint64(f.DataLimit))
}

View File

@ -2,16 +2,15 @@ package wire
import (
"bytes"
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/internal/qerr"
)
// ParseNextFrame parses the next frame
// It skips PADDING frames.
func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Frame, error) {
func ParseNextFrame(r *bytes.Reader, v protocol.VersionNumber) (Frame, error) {
for r.Len() != 0 {
typeByte, _ := r.ReadByte()
if typeByte == 0x0 { // PADDING frame
@ -19,154 +18,61 @@ func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Fra
}
r.UnreadByte()
if !v.UsesIETFFrameFormat() {
return parseGQUICFrame(r, typeByte, hdr, v)
}
return parseIETFFrame(r, typeByte, v)
return parseFrame(r, typeByte, v)
}
return nil, nil
}
func parseIETFFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) {
func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) {
var frame Frame
var err error
if typeByte&0xf8 == 0x10 {
if typeByte&0xf8 == 0x8 {
frame, err = parseStreamFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidStreamData, err.Error())
return nil, qerr.Error(qerr.InvalidFrameData, err.Error())
}
return frame, err
return frame, nil
}
// TODO: implement all IETF QUIC frame types
switch typeByte {
case 0x1:
frame, err = parseRstStreamFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
}
case 0x2:
frame, err = parseConnectionCloseFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
}
case 0x4:
frame, err = parseMaxDataFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x5:
frame, err = parseMaxStreamDataFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x6:
frame, err = parseMaxStreamIDFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x7:
frame, err = parsePingFrame(r, v)
case 0x8:
frame, err = parseBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0x9:
frame, err = parseStreamBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0xa:
frame, err = parseStreamIDBlockedFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xc:
case 0x2, 0x3:
frame, err = parseAckFrame(r, v)
case 0x4:
frame, err = parseResetStreamFrame(r, v)
case 0x5:
frame, err = parseStopSendingFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xd:
frame, err = parseAckFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error())
}
case 0xe:
frame, err = parsePathChallengeFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0xf:
frame, err = parsePathResponseFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidFrameData, err.Error())
}
case 0x1a:
frame, err = parseAckEcnFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error())
}
default:
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
}
return frame, err
}
func parseGQUICFrame(r *bytes.Reader, typeByte byte, hdr *Header, v protocol.VersionNumber) (Frame, error) {
var frame Frame
var err error
if typeByte&0x80 == 0x80 {
frame, err = parseStreamFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidStreamData, err.Error())
}
return frame, err
} else if typeByte&0xc0 == 0x40 {
frame, err = parseAckFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidAckData, err.Error())
}
return frame, err
}
switch typeByte {
case 0x1:
frame, err = parseRstStreamFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
}
case 0x2:
frame, err = parseConnectionCloseFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
}
case 0x3:
frame, err = parseGoawayFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidGoawayData, err.Error())
}
case 0x4:
frame, err = parseWindowUpdateFrame(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
}
case 0x5:
frame, err = parseBlockedFrameLegacy(r, v)
if err != nil {
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
}
case 0x6:
if !v.UsesStopWaitingFrames() {
err = errors.New("STOP_WAITING frames not supported by this QUIC version")
break
}
frame, err = parseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, v)
if err != nil {
err = qerr.Error(qerr.InvalidStopWaitingData, err.Error())
}
frame, err = parseCryptoFrame(r, v)
case 0x7:
frame, err = parsePingFrame(r, v)
frame, err = parseNewTokenFrame(r, v)
case 0x10:
frame, err = parseMaxDataFrame(r, v)
case 0x11:
frame, err = parseMaxStreamDataFrame(r, v)
case 0x12, 0x13:
frame, err = parseMaxStreamsFrame(r, v)
case 0x14:
frame, err = parseDataBlockedFrame(r, v)
case 0x15:
frame, err = parseStreamDataBlockedFrame(r, v)
case 0x16, 0x17:
frame, err = parseStreamsBlockedFrame(r, v)
case 0x18:
frame, err = parseNewConnectionIDFrame(r, v)
case 0x19:
frame, err = parseRetireConnectionIDFrame(r, v)
case 0x1a:
frame, err = parsePathChallengeFrame(r, v)
case 0x1b:
frame, err = parsePathResponseFrame(r, v)
case 0x1c, 0x1d:
frame, err = parseConnectionCloseFrame(r, v)
default:
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
err = fmt.Errorf("unknown type byte 0x%x", typeByte)
}
return frame, err
if err != nil {
return nil, qerr.Error(qerr.InvalidFrameData, err.Error())
}
return frame, nil
}

View File

@ -1,68 +0,0 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// A GoawayFrame is a GOAWAY frame
type GoawayFrame struct {
ErrorCode qerr.ErrorCode
LastGoodStream protocol.StreamID
ReasonPhrase string
}
// parseGoawayFrame parses a GOAWAY frame
func parseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, error) {
frame := &GoawayFrame{}
if _, err := r.ReadByte(); err != nil {
return nil, err
}
errorCode, err := utils.BigEndian.ReadUint32(r)
if err != nil {
return nil, err
}
frame.ErrorCode = qerr.ErrorCode(errorCode)
lastGoodStream, err := utils.BigEndian.ReadUint32(r)
if err != nil {
return nil, err
}
frame.LastGoodStream = protocol.StreamID(lastGoodStream)
reasonPhraseLen, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return nil, err
}
if reasonPhraseLen > uint16(protocol.MaxReceivePacketSize) {
return nil, qerr.Error(qerr.InvalidGoawayData, "reason phrase too long")
}
reasonPhrase := make([]byte, reasonPhraseLen)
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
return nil, err
}
frame.ReasonPhrase = string(reasonPhrase)
return frame, nil
}
func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x03)
utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
utils.BigEndian.WriteUint32(b, uint32(f.LastGoodStream))
utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase)))
b.WriteString(f.ReasonPhrase)
return nil
}
// Length of a written frame
func (f *GoawayFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase))
}

View File

@ -3,7 +3,6 @@ package wire
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -11,10 +10,7 @@ import (
)
// Header is the header of a QUIC packet.
// It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header.
type Header struct {
IsPublicHeader bool
Raw []byte
Version protocol.VersionNumber
@ -29,12 +25,6 @@ type Header struct {
IsVersionNegotiation bool
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
// only needed for the gQUIC Public Header
VersionFlag bool
ResetFlag bool
DiversificationNonce []byte
// only needed for the IETF Header
Type protocol.PacketType
IsLongHeader bool
KeyPhase int
@ -42,15 +32,8 @@ type Header struct {
Token []byte
}
var errInvalidPacketNumberLen = errors.New("invalid packet number length")
// Write writes the Header.
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
if !ver.UsesIETFHeaderFormat() {
h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
return h.writePublicHeader(b, pers, ver)
}
// write an IETF QUIC header
if h.IsLongHeader {
return h.writeLongHeader(b, ver)
}
@ -69,7 +52,7 @@ func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) erro
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
if h.Type == protocol.PacketTypeInitial {
utils.WriteVarInt(b, uint64(len(h.Token)))
b.Write(h.Token)
}
@ -89,176 +72,36 @@ func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) erro
return nil
}
if v.UsesLengthInHeader() {
utils.WriteVarInt(b, uint64(h.PayloadLen))
}
if v.UsesVarintPacketNumbers() {
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
if len(h.DiversificationNonce) != 32 {
return errors.New("invalid diversification nonce length")
}
b.Write(h.DiversificationNonce)
}
return nil
utils.WriteVarInt(b, uint64(h.PayloadLen))
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
typeByte := byte(0x30)
typeByte |= byte(h.KeyPhase << 6)
if !v.UsesVarintPacketNumbers() {
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
case protocol.PacketNumberLen2:
typeByte |= 0x1
case protocol.PacketNumberLen4:
typeByte |= 0x2
default:
return errInvalidPacketNumberLen
}
}
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
if !v.UsesVarintPacketNumbers() {
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
}
return nil
}
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
// writePublicHeader writes a Public Header.
func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
return errors.New("PublicHeader: Can only write regular packets")
}
if h.SrcConnectionID.Len() != 0 {
return errors.New("PublicHeader: SrcConnectionID must not be set")
}
if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
}
publicFlagByte := uint8(0x00)
if h.VersionFlag {
publicFlagByte |= 0x01
}
if h.DestConnectionID.Len() > 0 {
publicFlagByte |= 0x08
}
if len(h.DiversificationNonce) > 0 {
if len(h.DiversificationNonce) != 32 {
return errors.New("invalid diversification nonce length")
}
publicFlagByte |= 0x04
}
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
publicFlagByte |= 0x00
case protocol.PacketNumberLen2:
publicFlagByte |= 0x10
case protocol.PacketNumberLen4:
publicFlagByte |= 0x20
}
b.WriteByte(publicFlagByte)
if h.DestConnectionID.Len() > 0 {
b.Write(h.DestConnectionID)
}
if h.VersionFlag && pers == protocol.PerspectiveClient {
utils.BigEndian.WriteUint32(b, uint32(h.Version))
}
if len(h.DiversificationNonce) > 0 {
b.Write(h.DiversificationNonce)
}
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen6:
return errInvalidPacketNumberLen
default:
return errors.New("PublicHeader: PacketNumberLen not set")
}
return nil
}
// GetLength determines the length of the Header.
func (h *Header) GetLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
if !v.UsesIETFHeaderFormat() {
return h.getPublicHeaderLength()
}
return h.getHeaderLength(v)
}
func (h *Header) getHeaderLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
if v.UsesLengthInHeader() {
length += utils.VarIntLen(uint64(h.PayloadLen))
}
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.PayloadLen))
if h.Type == protocol.PacketTypeInitial {
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
}
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
length += protocol.ByteCount(len(h.DiversificationNonce))
}
return length, nil
return length
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
length += protocol.ByteCount(h.PacketNumberLen)
return length, nil
}
// getPublicHeaderLength gets the length of the publicHeader in bytes.
// It can only be called for regular packets.
func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
length := protocol.ByteCount(1) // 1 byte for public flags
if h.PacketNumberLen == protocol.PacketNumberLen6 {
return 0, errInvalidPacketNumberLen
}
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
return 0, errPacketNumberLenNotSet
}
length += protocol.ByteCount(h.PacketNumberLen)
length += protocol.ByteCount(h.DestConnectionID.Len())
// Version Number in packets sent by the client
if h.VersionFlag {
length += 4
}
length += protocol.ByteCount(len(h.DiversificationNonce))
return length, nil
return length
}
// Log logs the Header
func (h *Header) Log(logger utils.Logger) {
if h.IsPublicHeader {
h.logPublicHeader(logger)
} else {
h.logHeader(logger)
}
}
func (h *Header) logHeader(logger utils.Logger) {
if h.IsLongHeader {
if h.Version == 0 {
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
@ -275,14 +118,6 @@ func (h *Header) logHeader(logger utils.Logger) {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
return
}
if h.Version == protocol.Version44 {
var divNonce string
if h.Type == protocol.PacketType0RTT {
divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce)
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version)
return
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
}
} else {
@ -290,14 +125,6 @@ func (h *Header) logHeader(logger utils.Logger) {
}
}
func (h *Header) logPublicHeader(logger utils.Logger) {
ver := "(unset)"
if h.Version != 0 {
ver = h.Version.String()
}
logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
}
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
dcil, err := encodeSingleConnIDLen(dest)
if err != nil {

View File

@ -6,8 +6,8 @@ import (
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// The InvariantHeader is the version independent part of the header
@ -32,22 +32,10 @@ func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Invariant
// If this is not a Long Header, it could either be a Public Header or a Short Header.
if !h.IsLongHeader {
// In the Public Header 0x8 is the Connection ID Flag.
// In the IETF Short Header:
// * 0x8 it is the gQUIC Demultiplexing bit, and always 0.
// * 0x20 and 0x10 are always 1.
var connIDLen int
if typeByte&0x8 > 0 { // Public Header containing a connection ID
connIDLen = 8
}
if typeByte&0x38 == 0x30 { // Short Header
connIDLen = shortHeaderConnIDLen
}
if connIDLen > 0 {
h.DestConnectionID, err = protocol.ReadConnectionID(b, connIDLen)
if err != nil {
return nil, err
}
var err error
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
if err != nil {
return nil, err
}
return h, nil
}
@ -81,15 +69,6 @@ func (iv *InvariantHeader) Parse(b *bytes.Reader, sentBy protocol.Perspective, v
}
return iv.parseLongHeader(b, sentBy, ver)
}
// The Public Header never uses 6 byte packet numbers.
// Therefore, the third and fourth bit will never be 11.
// For the Short Header, the third and fourth bit are always 11.
if iv.typeByte&0x30 != 0x30 {
if sentBy == protocol.PerspectiveServer && iv.typeByte&0x1 > 0 {
return iv.parseVersionNegotiationPacket(b)
}
return iv.parsePublicHeader(b, sentBy, ver)
}
return iv.parseShortHeader(b, ver)
}
@ -104,7 +83,6 @@ func (iv *InvariantHeader) toHeader() *Header {
func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Header, error) {
h := iv.toHeader()
h.VersionFlag = true
if b.Len() == 0 {
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
}
@ -145,7 +123,7 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Pers
return h, nil
}
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
@ -159,37 +137,17 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Pers
}
}
if v.UsesLengthInHeader() {
pl, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
h.PayloadLen = protocol.ByteCount(pl)
pl, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
if v.UsesVarintPacketNumbers() {
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = pn
h.PacketNumberLen = pnLen
} else {
pn, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, err
}
h.PacketNumber = protocol.PacketNumber(pn)
h.PacketNumberLen = protocol.PacketNumberLen4
}
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 && sentBy == protocol.PerspectiveServer {
h.DiversificationNonce = make([]byte, 32)
if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
}
h.PayloadLen = protocol.ByteCount(pl)
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = pn
h.PacketNumberLen = pnLen
return h, nil
}
@ -198,76 +156,12 @@ func (iv *InvariantHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionN
h := iv.toHeader()
h.KeyPhase = int(iv.typeByte&0x40) >> 6
if v.UsesVarintPacketNumbers() {
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = pn
h.PacketNumberLen = pnLen
} else {
switch iv.typeByte & 0x3 {
case 0x0:
h.PacketNumberLen = protocol.PacketNumberLen1
case 0x1:
h.PacketNumberLen = protocol.PacketNumberLen2
case 0x2:
h.PacketNumberLen = protocol.PacketNumberLen4
default:
return nil, errInvalidPacketNumberLen
}
p, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
if err != nil {
return nil, err
}
h.PacketNumber = protocol.PacketNumber(p)
}
return h, nil
}
func (iv *InvariantHeader) parsePublicHeader(b *bytes.Reader, sentBy protocol.Perspective, ver protocol.VersionNumber) (*Header, error) {
h := iv.toHeader()
h.IsPublicHeader = true
h.ResetFlag = iv.typeByte&0x2 > 0
if h.ResetFlag {
return h, nil
}
h.VersionFlag = iv.typeByte&0x1 > 0
if h.VersionFlag && sentBy == protocol.PerspectiveClient {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, err
}
h.Version = protocol.VersionNumber(v)
}
// Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server.
// It doesn't have any meaning when sent by the client.
if sentBy == protocol.PerspectiveServer && iv.typeByte&0x4 > 0 {
h.DiversificationNonce = make([]byte, 32)
if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
}
}
switch iv.typeByte & 0x30 {
case 0x00:
h.PacketNumberLen = protocol.PacketNumberLen1
case 0x10:
h.PacketNumberLen = protocol.PacketNumberLen2
case 0x20:
h.PacketNumberLen = protocol.PacketNumberLen4
}
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = protocol.PacketNumber(pn)
h.PacketNumber = pn
h.PacketNumberLen = pnLen
return h, nil
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"strings"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
@ -17,14 +18,11 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) {
dir = "->"
}
switch f := frame.(type) {
case *CryptoFrame:
dataLen := protocol.ByteCount(len(f.Data))
logger.Debugf("\t%s &wire.CryptoFrame{Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.Offset, dataLen, f.Offset+dataLen)
case *StreamFrame:
logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
case *StopWaitingFrame:
if sent {
logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
} else {
logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
}
case *AckFrame:
if len(f.AckRanges) > 1 {
ackRanges := make([]string, len(f.AckRanges))

Some files were not shown because too many files have changed in this diff Show More