mirror of https://github.com/v2ray/v2ray-core
parent
6870ead73e
commit
19926b8e4f
|
@ -5,37 +5,85 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"v2ray.com/core/transport/internet/tls"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/transport/internet"
|
||||
"v2ray.com/core/transport/internet/tls"
|
||||
)
|
||||
|
||||
type clientSessions struct {
|
||||
access sync.Mutex
|
||||
sessions map[net.Destination]quic.Session
|
||||
sessions map[net.Destination][]quic.Session
|
||||
}
|
||||
|
||||
func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (quic.Session, error) {
|
||||
func removeInactiveSessions(sessions []quic.Session) []quic.Session {
|
||||
lastActive := 0
|
||||
for _, s := range sessions {
|
||||
active := true
|
||||
select {
|
||||
case <-s.Context().Done():
|
||||
active = false
|
||||
default:
|
||||
}
|
||||
if active {
|
||||
sessions[lastActive] = s
|
||||
lastActive++
|
||||
}
|
||||
}
|
||||
|
||||
if lastActive < len(sessions) {
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
sessions[i] = nil
|
||||
}
|
||||
sessions = sessions[:lastActive]
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
func openStream(sessions []quic.Session) (quic.Stream, net.Addr, error) {
|
||||
for _, s := range sessions {
|
||||
stream, err := s.OpenStream()
|
||||
if err != nil {
|
||||
newError("failed to create stream").Base(err).WriteToLog()
|
||||
continue
|
||||
}
|
||||
return stream, s.LocalAddr(), nil
|
||||
}
|
||||
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
|
||||
if s.sessions == nil {
|
||||
s.sessions = make(map[net.Destination]quic.Session)
|
||||
s.sessions = make(map[net.Destination][]quic.Session)
|
||||
}
|
||||
|
||||
dest := net.DestinationFromAddr(destAddr)
|
||||
|
||||
if session, found := s.sessions[dest]; found {
|
||||
select {
|
||||
case <-session.Context().Done():
|
||||
// Session has been closed. Creating a new one.
|
||||
default:
|
||||
return session, nil
|
||||
}
|
||||
var sessions []quic.Session
|
||||
if s, found := s.sessions[dest]; found {
|
||||
sessions = s
|
||||
}
|
||||
|
||||
sessions = removeInactiveSessions(sessions)
|
||||
s.sessions[dest] = sessions
|
||||
|
||||
stream, local, err := openStream(sessions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if stream != nil {
|
||||
return &interConn{
|
||||
stream: stream,
|
||||
local: local,
|
||||
remote: destAddr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
|
||||
|
@ -47,13 +95,12 @@ func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig
|
|||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
Versions: []quic.VersionNumber{quic.VersionMilestone0_10_0},
|
||||
ConnectionIDLength: 12,
|
||||
KeepAlive: true,
|
||||
HandshakeTimeout: time.Second * 4,
|
||||
IdleTimeout: time.Second * 300,
|
||||
MaxReceiveStreamFlowControlWindow: 128 * 1024,
|
||||
MaxReceiveConnectionFlowControlWindow: 512 * 1024,
|
||||
IdleTimeout: time.Second * 60,
|
||||
MaxReceiveStreamFlowControlWindow: 256 * 1024,
|
||||
MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
|
||||
MaxIncomingUniStreams: -1,
|
||||
}
|
||||
|
||||
|
@ -69,8 +116,16 @@ func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig
|
|||
return nil, err
|
||||
}
|
||||
|
||||
s.sessions[dest] = session
|
||||
return session, nil
|
||||
s.sessions[dest] = append(sessions, session)
|
||||
stream, err = session.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &interConn{
|
||||
stream: stream,
|
||||
local: session.LocalAddr(),
|
||||
remote: destAddr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var client clientSessions
|
||||
|
@ -91,21 +146,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
|
|||
|
||||
config := streamSettings.ProtocolSettings.(*Config)
|
||||
|
||||
session, err := client.getSession(destAddr, config, tlsConfig, streamSettings.SocketSettings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := session.OpenStreamSync()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &interConn{
|
||||
stream: conn,
|
||||
local: session.LocalAddr(),
|
||||
remote: destAddr,
|
||||
}, nil
|
||||
return client.openConnection(destAddr, config, tlsConfig, streamSettings.SocketSettings)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
|
@ -84,14 +84,13 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
|
|||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
Versions: []quic.VersionNumber{quic.VersionMilestone0_10_0},
|
||||
ConnectionIDLength: 12,
|
||||
KeepAlive: true,
|
||||
HandshakeTimeout: time.Second * 4,
|
||||
IdleTimeout: time.Second * 300,
|
||||
MaxReceiveStreamFlowControlWindow: 128 * 1024,
|
||||
MaxReceiveConnectionFlowControlWindow: 512 * 1024,
|
||||
MaxIncomingStreams: 256,
|
||||
IdleTimeout: time.Second * 60,
|
||||
MaxReceiveStreamFlowControlWindow: 256 * 1024,
|
||||
MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
|
||||
MaxIncomingStreams: 64,
|
||||
MaxIncomingUniStreams: -1,
|
||||
}
|
||||
|
||||
|
|
|
@ -8,12 +8,19 @@
|
|||
[![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
|
||||
[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/)
|
||||
|
||||
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
|
||||
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go. It roughly implements the [IETF QUIC draft](https://github.com/quicwg/base-drafts), although we don't fully support any of the draft versions at the moment.
|
||||
|
||||
## Roadmap
|
||||
## Version compatibility
|
||||
|
||||
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
|
||||
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
|
||||
Since quic-go is under active development, there's no guarantee that two builds of different commits are interoperable. The QUIC version used in the *master* branch is just a placeholder, and should not be considered stable.
|
||||
|
||||
If you want to use quic-go as a library in other projects, please consider using a [tagged release](https://github.com/lucas-clemente/quic-go/releases). These releases expose [experimental QUIC versions](https://github.com/quicwg/base-drafts/wiki/QUIC-Versions), which are guaranteed to be stable.
|
||||
|
||||
## Google QUIC
|
||||
|
||||
quic-go used to support both the QUIC versions supported by Google Chrome and QUIC as deployed on Google's servers, as well as IETF QUIC. Due to the divergence of the two protocols, we decided to not support both versions any more.
|
||||
|
||||
The *master* branch **only** supports IETF QUIC. For Google QUIC support, please refer to the [gquic branch](https://github.com/lucas-clemente/quic-go/tree/gquic).
|
||||
|
||||
## Guides
|
||||
|
||||
|
@ -27,31 +34,19 @@ Running tests:
|
|||
|
||||
go test ./...
|
||||
|
||||
### Running the example server
|
||||
### HTTP mapping
|
||||
|
||||
go run example/main.go -www /var/www/
|
||||
|
||||
Using the `quic_client` from chromium:
|
||||
|
||||
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
|
||||
|
||||
Using Chrome:
|
||||
|
||||
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
|
||||
We're currently not implementing the HTTP mapping as described in the [QUIC over HTTP draft](https://quicwg.org/base-drafts/draft-ietf-quic-http.html). The HTTP mapping here is a leftover from Google QUIC.
|
||||
|
||||
### QUIC without HTTP/2
|
||||
|
||||
Take a look at [this echo example](example/echo/echo.go).
|
||||
|
||||
### Using the example client
|
||||
|
||||
go run example/client/main.go https://clemente.io
|
||||
|
||||
## Usage
|
||||
|
||||
### As a server
|
||||
|
||||
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
|
||||
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
|
||||
|
||||
```go
|
||||
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
||||
|
|
|
@ -10,6 +10,9 @@ environment:
|
|||
- GOARCH: 386
|
||||
- GOARCH: amd64
|
||||
|
||||
hosts:
|
||||
quic.clemente.io: 127.0.0.1
|
||||
|
||||
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
|
||||
|
||||
install:
|
||||
|
@ -19,7 +22,6 @@ install:
|
|||
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
|
||||
- echo %PATH%
|
||||
- echo %GOPATH%
|
||||
- git submodule update --init --recursive
|
||||
- go get github.com/onsi/ginkgo/ginkgo
|
||||
- go get github.com/onsi/gomega
|
||||
- go version
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
|
@ -9,12 +8,11 @@ import (
|
|||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
|
@ -25,29 +23,25 @@ type client struct {
|
|||
// If it is started with Dial, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
|
||||
hostname string
|
||||
|
||||
packetHandlers packetHandlerManager
|
||||
|
||||
token []byte
|
||||
numRetries int
|
||||
token []byte
|
||||
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
tlsConf *tls.Config
|
||||
mintConf *mint.Config
|
||||
config *Config
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
||||
handshakeChan chan struct{}
|
||||
closeCallback func(protocol.ConnectionID)
|
||||
|
||||
session quicSession
|
||||
|
||||
|
@ -128,18 +122,11 @@ func dialContext(
|
|||
createdPacketConn bool,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
if !createdPacketConn {
|
||||
for _, v := range config.Versions {
|
||||
if v == protocol.Version44 {
|
||||
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
|
||||
}
|
||||
}
|
||||
}
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, createdPacketConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -156,16 +143,14 @@ func newClient(
|
|||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
host string,
|
||||
closeCallback func(protocol.ConnectionID),
|
||||
createdPacketConn bool,
|
||||
) (*client, error) {
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
}
|
||||
if hostname == "" {
|
||||
if tlsConf.ServerName == "" {
|
||||
var err error
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
tlsConf.ServerName, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -179,19 +164,13 @@ func newClient(
|
|||
}
|
||||
}
|
||||
}
|
||||
onClose := func(protocol.ConnectionID) {}
|
||||
if closeCallback != nil {
|
||||
onClose = closeCallback
|
||||
}
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
createdPacketConn: createdPacketConn,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, c.generateConnectionIDs()
|
||||
|
@ -219,11 +198,11 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
|||
|
||||
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
||||
if maxReceiveStreamFlowControlWindow == 0 {
|
||||
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
|
||||
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
|
||||
}
|
||||
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
|
||||
if maxReceiveConnectionFlowControlWindow == 0 {
|
||||
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
|
||||
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
maxIncomingStreams := config.MaxIncomingStreams
|
||||
if maxIncomingStreams == 0 {
|
||||
|
@ -241,17 +220,11 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
|||
if connIDLen == 0 && !createdPacketConn {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
for _, v := range versions {
|
||||
if v == protocol.Version44 {
|
||||
connIDLen = 0
|
||||
}
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
IdleTimeout: idleTimeout,
|
||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
ConnectionIDLength: connIDLen,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
|
@ -262,75 +235,26 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
|||
}
|
||||
|
||||
func (c *client) generateConnectionIDs() error {
|
||||
connIDLen := protocol.ConnectionIDLenGQUIC
|
||||
if c.version.UsesTLS() {
|
||||
connIDLen = c.config.ConnectionIDLength
|
||||
}
|
||||
srcConnID, err := generateConnectionID(connIDLen)
|
||||
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destConnID := srcConnID
|
||||
if c.version.UsesTLS() {
|
||||
destConnID, err = generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destConnID, err := generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.srcConnID = srcConnID
|
||||
c.destConnID = destConnID
|
||||
if c.version == protocol.Version44 {
|
||||
c.srcConnID = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
var err error
|
||||
if c.version.UsesTLS() {
|
||||
err = c.dialTLS(ctx)
|
||||
} else {
|
||||
err = c.dialGQUIC(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialGQUIC(ctx context.Context) error {
|
||||
if err := c.createNewGQUICSession(); err != nil {
|
||||
if err := c.createNewTLSSession(c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
err := c.establishSecureConnection(ctx)
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialTLS(ctx context.Context) error {
|
||||
params := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
|
||||
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
|
||||
DisableMigration: true,
|
||||
}
|
||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
|
||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mintConf.ExtensionHandler = extHandler
|
||||
mintConf.ServerName = c.hostname
|
||||
c.mintConf = mintConf
|
||||
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.establishSecureConnection(ctx)
|
||||
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
|
||||
return c.dial(ctx)
|
||||
}
|
||||
|
@ -340,9 +264,9 @@ func (c *client) dialTLS(ctx context.Context) error {
|
|||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
// It returns:
|
||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
|
||||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
// - when the connection is forward-secure
|
||||
func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
|
@ -387,35 +311,14 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if !c.version.UsesIETFHeaderFormat() {
|
||||
connID := p.header.DestConnectionID
|
||||
// reject packets with truncated connection id if we didn't request truncation
|
||||
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
|
||||
return errors.New("received packet with truncated connection ID, but didn't request truncation")
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
|
||||
}
|
||||
if p.header.ResetFlag {
|
||||
return c.handlePublicReset(p)
|
||||
}
|
||||
} else {
|
||||
// reject packets with the wrong connection ID
|
||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
|
||||
if p.header.IsLongHeader {
|
||||
switch p.header.Type {
|
||||
case protocol.PacketTypeRetry:
|
||||
c.handleRetryPacket(p.header)
|
||||
return nil
|
||||
case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
|
||||
default:
|
||||
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
|
||||
}
|
||||
if p.header.Type == protocol.PacketTypeRetry {
|
||||
c.handleRetryPacket(p.header)
|
||||
return nil
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
|
@ -428,22 +331,6 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handlePublicReset(p *receivedPacket) error {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return errors.New("Received a spoofed Public Reset")
|
||||
}
|
||||
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
}
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
|
@ -483,77 +370,56 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
||||
c.logger.Debugf("<- Received Retry")
|
||||
hdr.Log(c.logger)
|
||||
// A server that performs multiple retries must use a source connection ID of at least 8 bytes.
|
||||
// Only a server that won't send additional Retries can use shorter connection IDs.
|
||||
if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial)
|
||||
return
|
||||
}
|
||||
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
|
||||
c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
||||
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
||||
return
|
||||
}
|
||||
c.numRetries++
|
||||
if c.numRetries > protocol.MaxRetries {
|
||||
c.session.destroy(qerr.CryptoTooManyRejects)
|
||||
if hdr.SrcConnectionID.Equal(c.destConnID) {
|
||||
c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
|
||||
return
|
||||
}
|
||||
// If a token is already set, this means that we already received a Retry from the server.
|
||||
// Ignore this Retry packet.
|
||||
if len(c.token) > 0 {
|
||||
c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
|
||||
return
|
||||
}
|
||||
c.origDestConnID = c.destConnID
|
||||
c.destConnID = hdr.SrcConnectionID
|
||||
c.token = hdr.Token
|
||||
c.session.destroy(errCloseSessionForRetry)
|
||||
}
|
||||
|
||||
func (c *client) createNewGQUICSession() error {
|
||||
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
|
||||
params := &handshake.TransportParameters{
|
||||
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||
InitialMaxStreamDataUni: protocol.InitialMaxStreamData,
|
||||
InitialMaxData: protocol.InitialMaxData,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
MaxBidiStreams: uint64(c.config.MaxIncomingStreams),
|
||||
MaxUniStreams: uint64(c.config.MaxIncomingUniStreams),
|
||||
DisableMigration: true,
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
retireConnectionIDImpl: c.packetHandlers.Retire,
|
||||
removeConnectionIDImpl: c.packetHandlers.Remove,
|
||||
}
|
||||
sess, err := newClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
c.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.session = sess
|
||||
c.packetHandlers.Add(c.srcConnID, c)
|
||||
if c.config.RequestConnectionIDOmission {
|
||||
c.packetHandlers.Add(protocol.ConnectionID{}, c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
version protocol.VersionNumber,
|
||||
) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
sess, err := newTLSClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.token,
|
||||
c.origDestConnID,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.config,
|
||||
c.mintConf,
|
||||
paramsChan,
|
||||
1,
|
||||
c.tlsConf,
|
||||
params,
|
||||
c.initialVersion,
|
||||
c.logger,
|
||||
c.version,
|
||||
)
|
||||
|
|
|
@ -1,41 +1,108 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStream interface {
|
||||
StreamID() protocol.StreamID
|
||||
io.Reader
|
||||
// for receiving data
|
||||
HandleCryptoFrame(*wire.CryptoFrame) error
|
||||
GetCryptoData() []byte
|
||||
Finish() error
|
||||
// for sending data
|
||||
io.Writer
|
||||
handleStreamFrame(*wire.StreamFrame) error
|
||||
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
closeForShutdown(error)
|
||||
setReadOffset(protocol.ByteCount)
|
||||
// methods needed for flow control
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
HasData() bool
|
||||
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
|
||||
}
|
||||
|
||||
type cryptoStreamImpl struct {
|
||||
*stream
|
||||
queue *frameSorter
|
||||
msgBuf []byte
|
||||
|
||||
highestOffset protocol.ByteCount
|
||||
finished bool
|
||||
|
||||
writeOffset protocol.ByteCount
|
||||
writeBuf []byte
|
||||
}
|
||||
|
||||
var _ cryptoStream = &cryptoStreamImpl{}
|
||||
|
||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
|
||||
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||
return &cryptoStreamImpl{str}
|
||||
func newCryptoStream() cryptoStream {
|
||||
return &cryptoStreamImpl{
|
||||
queue: newFrameSorter(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
|
||||
s.receiveStream.readOffset = offset
|
||||
s.receiveStream.frameQueue.readPos = offset
|
||||
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
|
||||
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
|
||||
return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)
|
||||
}
|
||||
if s.finished {
|
||||
if highestOffset > s.highestOffset {
|
||||
// reject crypto data received after this stream was already finished
|
||||
return errors.New("received crypto data after change of encryption level")
|
||||
}
|
||||
// ignore data with a smaller offset than the highest received
|
||||
// could e.g. be a retransmission
|
||||
return nil
|
||||
}
|
||||
s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
|
||||
if err := s.queue.Push(f.Data, f.Offset, false); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
data, _ := s.queue.Pop()
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
s.msgBuf = append(s.msgBuf, data...)
|
||||
}
|
||||
}
|
||||
|
||||
// GetCryptoData retrieves data that was received in CRYPTO frames
|
||||
func (s *cryptoStreamImpl) GetCryptoData() []byte {
|
||||
if len(s.msgBuf) < 4 {
|
||||
return nil
|
||||
}
|
||||
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
|
||||
if len(s.msgBuf) < msgLen {
|
||||
return nil
|
||||
}
|
||||
msg := make([]byte, msgLen)
|
||||
copy(msg, s.msgBuf[:msgLen])
|
||||
s.msgBuf = s.msgBuf[msgLen:]
|
||||
return msg
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) Finish() error {
|
||||
if s.queue.HasMoreData() {
|
||||
return errors.New("encryption level changed, but crypto stream has more data to read")
|
||||
}
|
||||
s.finished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Writes writes data that should be sent out in CRYPTO frames
|
||||
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
|
||||
s.writeBuf = append(s.writeBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) HasData() bool {
|
||||
return len(s.writeBuf) > 0
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||
n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
|
||||
f.Data = s.writeBuf[:n]
|
||||
s.writeBuf = s.writeBuf[n:]
|
||||
s.writeOffset += n
|
||||
return f
|
||||
}
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoDataHandler interface {
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
}
|
||||
|
||||
type cryptoStreamManager struct {
|
||||
cryptoHandler cryptoDataHandler
|
||||
|
||||
initialStream cryptoStream
|
||||
handshakeStream cryptoStream
|
||||
}
|
||||
|
||||
func newCryptoStreamManager(
|
||||
cryptoHandler cryptoDataHandler,
|
||||
initialStream cryptoStream,
|
||||
handshakeStream cryptoStream,
|
||||
) *cryptoStreamManager {
|
||||
return &cryptoStreamManager{
|
||||
cryptoHandler: cryptoHandler,
|
||||
initialStream: initialStream,
|
||||
handshakeStream: handshakeStream,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
|
||||
var str cryptoStream
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
str = m.initialStream
|
||||
case protocol.EncryptionHandshake:
|
||||
str = m.handshakeStream
|
||||
default:
|
||||
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
if err := str.HandleCryptoFrame(frame); err != nil {
|
||||
return false, err
|
||||
}
|
||||
for {
|
||||
data := str.GetCryptoData()
|
||||
if data == nil {
|
||||
return false, nil
|
||||
}
|
||||
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
|
||||
return true, str.Finish()
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
Before Width: | Height: | Size: 17 KiB |
Binary file not shown.
|
@ -1,65 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<svg width="601px" height="242px" viewBox="0 0 601 242" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<!-- Generator: Sketch 3.7.2 (28276) - http://www.bohemiancoding.com/sketch -->
|
||||
<title>quic</title>
|
||||
<desc>Created with Sketch.</desc>
|
||||
<defs></defs>
|
||||
<g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="Group-4-Copy" transform="translate(586.087964, 15.755663) rotate(-26.000000) translate(-586.087964, -15.755663) translate(573.036290, 1.253803)">
|
||||
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
|
||||
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
|
||||
</g>
|
||||
<g id="Group-3-Copy" transform="translate(499.346306, 15.541651) rotate(25.000000) translate(-499.346306, -15.541651) translate(486.346306, 0.541651)">
|
||||
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
|
||||
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
|
||||
</g>
|
||||
<g id="Group-3" transform="translate(1.896652, 21.000000)">
|
||||
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
|
||||
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
|
||||
</g>
|
||||
<g id="Group-4" transform="translate(117.000000, 22.000000)">
|
||||
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
|
||||
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
|
||||
</g>
|
||||
<g id="Group-2" transform="translate(26.000000, 40.000000)">
|
||||
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
|
||||
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
|
||||
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
|
||||
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
|
||||
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
|
||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
|
||||
</g>
|
||||
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
|
||||
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
|
||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
|
||||
</g>
|
||||
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
|
||||
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
|
||||
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
|
||||
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
|
||||
</g>
|
||||
</g>
|
||||
<g id="Group-2-Copy" transform="translate(544.820291, 73.354278) scale(-1, 1) translate(-544.820291, -73.354278) translate(498.000000, 40.000000)">
|
||||
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
|
||||
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
|
||||
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
|
||||
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
|
||||
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
|
||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
|
||||
</g>
|
||||
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
|
||||
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
|
||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
|
||||
</g>
|
||||
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
|
||||
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
|
||||
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
|
||||
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
|
||||
</g>
|
||||
</g>
|
||||
<text id="QUIC" font-family="SourceCodePro-Light, Source Code Pro" font-size="288" font-weight="300" letter-spacing="-20" fill="#000000">
|
||||
<tspan x="-13.6" y="197">QUI</tspan>
|
||||
<tspan x="444.8" y="197">C</tspan>
|
||||
</text>
|
||||
</g>
|
||||
</svg>
|
Before Width: | Height: | Size: 11 KiB |
|
@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
|
|||
s.readPos += protocol.ByteCount(len(data))
|
||||
return data, s.readPos >= s.finalOffset
|
||||
}
|
||||
|
||||
// HasMoreData says if there is any more data queued at *any* offset.
|
||||
func (s *frameSorter) HasMoreData() bool {
|
||||
return len(s.queue) > 0
|
||||
}
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type framer interface {
|
||||
QueueControlFrame(wire.Frame)
|
||||
AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
|
||||
|
||||
AddActiveStream(protocol.StreamID)
|
||||
AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
|
||||
}
|
||||
|
||||
type framerI struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
streamGetter streamGetter
|
||||
version protocol.VersionNumber
|
||||
|
||||
activeStreams map[protocol.StreamID]struct{}
|
||||
streamQueue []protocol.StreamID
|
||||
|
||||
controlFrameMutex sync.Mutex
|
||||
controlFrames []wire.Frame
|
||||
}
|
||||
|
||||
var _ framer = &framerI{}
|
||||
|
||||
func newFramer(
|
||||
streamGetter streamGetter,
|
||||
v protocol.VersionNumber,
|
||||
) framer {
|
||||
return &framerI{
|
||||
streamGetter: streamGetter,
|
||||
activeStreams: make(map[protocol.StreamID]struct{}),
|
||||
version: v,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *framerI) QueueControlFrame(frame wire.Frame) {
|
||||
f.controlFrameMutex.Lock()
|
||||
f.controlFrames = append(f.controlFrames, frame)
|
||||
f.controlFrameMutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framerI) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) {
|
||||
var length protocol.ByteCount
|
||||
f.controlFrameMutex.Lock()
|
||||
for len(f.controlFrames) > 0 {
|
||||
frame := f.controlFrames[len(f.controlFrames)-1]
|
||||
frameLen := frame.Length(f.version)
|
||||
if length+frameLen > maxLen {
|
||||
break
|
||||
}
|
||||
frames = append(frames, frame)
|
||||
length += frameLen
|
||||
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
|
||||
}
|
||||
f.controlFrameMutex.Unlock()
|
||||
return frames, length
|
||||
}
|
||||
|
||||
func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
||||
f.mutex.Lock()
|
||||
if _, ok := f.activeStreams[id]; !ok {
|
||||
f.streamQueue = append(f.streamQueue, id)
|
||||
f.activeStreams[id] = struct{}{}
|
||||
}
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame {
|
||||
var length protocol.ByteCount
|
||||
f.mutex.Lock()
|
||||
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
|
||||
numActiveStreams := len(f.streamQueue)
|
||||
for i := 0; i < numActiveStreams; i++ {
|
||||
if maxLen-length < protocol.MinStreamFrameSize {
|
||||
break
|
||||
}
|
||||
id := f.streamQueue[0]
|
||||
f.streamQueue = f.streamQueue[1:]
|
||||
// This should never return an error. Better check it anyway.
|
||||
// The stream will only be in the streamQueue, if it enqueued itself there.
|
||||
str, err := f.streamGetter.GetOrOpenSendStream(id)
|
||||
// The stream can be nil if it completed after it said it had data.
|
||||
if str == nil || err != nil {
|
||||
delete(f.activeStreams, id)
|
||||
continue
|
||||
}
|
||||
frame, hasMoreData := str.popStreamFrame(maxLen - length)
|
||||
if hasMoreData { // put the stream back in the queue (at the end)
|
||||
f.streamQueue = append(f.streamQueue, id)
|
||||
} else { // no more data to send. Stream is not active any more
|
||||
delete(f.activeStreams, id)
|
||||
}
|
||||
if frame == nil { // can happen if the receiveStream was canceled after it said it had data
|
||||
continue
|
||||
}
|
||||
frames = append(frames, frame)
|
||||
length += frame.Length(f.version)
|
||||
}
|
||||
f.mutex.Unlock()
|
||||
return frames
|
||||
}
|
|
@ -1,314 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/idna"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type roundTripperOpts struct {
|
||||
DisableCompression bool
|
||||
}
|
||||
|
||||
var dialAddr = quic.DialAddr
|
||||
|
||||
// client is a HTTP2 client doing QUIC requests
|
||||
type client struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *quic.Config
|
||||
opts *roundTripperOpts
|
||||
|
||||
hostname string
|
||||
handshakeErr error
|
||||
dialOnce sync.Once
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
session quic.Session
|
||||
headerStream quic.Stream
|
||||
headerErr *qerr.QuicError
|
||||
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
||||
requestWriter *requestWriter
|
||||
|
||||
responses map[protocol.StreamID]chan *http.Response
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = &client{}
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
RequestConnectionIDOmission: true,
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
// newClient creates a new client
|
||||
func newClient(
|
||||
hostname string,
|
||||
tlsConfig *tls.Config,
|
||||
opts *roundTripperOpts,
|
||||
quicConfig *quic.Config,
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||
) *client {
|
||||
config := defaultQuicConfig
|
||||
if quicConfig != nil {
|
||||
config = quicConfig
|
||||
}
|
||||
return &client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
tlsConf: tlsConfig,
|
||||
config: config,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
dialer: dialer,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
}
|
||||
|
||||
// dial dials the connection
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
if c.dialer != nil {
|
||||
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||
} else {
|
||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// once the version has been negotiated, open the header stream
|
||||
c.headerStream, err = c.session.OpenStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
|
||||
go c.handleHeaderStream()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleHeaderStream() {
|
||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
||||
|
||||
var err error
|
||||
for err == nil {
|
||||
err = c.readResponse(h2framer, decoder)
|
||||
}
|
||||
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
|
||||
c.logger.Debugf("Error handling header stream: %s", err)
|
||||
}
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
||||
// stop all running request
|
||||
close(c.headerErrored)
|
||||
}
|
||||
|
||||
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
||||
frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hframe, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
return errors.New("not a headers frame")
|
||||
}
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
||||
}
|
||||
|
||||
rsp, err := responseFromHeaders(mhframe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
responseChan <- rsp
|
||||
return nil
|
||||
}
|
||||
|
||||
// Roundtrip executes a request and returns a response
|
||||
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// TODO: add port to address, if it doesn't have one
|
||||
if req.URL.Scheme != "https" {
|
||||
return nil, errors.New("quic http2: unsupported scheme")
|
||||
}
|
||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
||||
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||
}
|
||||
|
||||
c.dialOnce.Do(func() {
|
||||
c.handshakeErr = c.dial()
|
||||
})
|
||||
|
||||
if c.handshakeErr != nil {
|
||||
return nil, c.handshakeErr
|
||||
}
|
||||
|
||||
hasBody := (req.Body != nil)
|
||||
|
||||
responseChan := make(chan *http.Response)
|
||||
dataStream, err := c.session.OpenStreamSync()
|
||||
if err != nil {
|
||||
_ = c.closeWithError(err)
|
||||
return nil, err
|
||||
}
|
||||
c.mutex.Lock()
|
||||
c.responses[dataStream.StreamID()] = responseChan
|
||||
c.mutex.Unlock()
|
||||
|
||||
var requestedGzip bool
|
||||
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
||||
requestedGzip = true
|
||||
}
|
||||
// TODO: add support for trailers
|
||||
endStream := !hasBody
|
||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
||||
if err != nil {
|
||||
_ = c.closeWithError(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resc := make(chan error, 1)
|
||||
if hasBody {
|
||||
go func() {
|
||||
resc <- c.writeRequestBody(dataStream, req.Body)
|
||||
}()
|
||||
}
|
||||
|
||||
var res *http.Response
|
||||
|
||||
var receivedResponse bool
|
||||
var bodySent bool
|
||||
|
||||
if !hasBody {
|
||||
bodySent = true
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
for !(bodySent && receivedResponse) {
|
||||
select {
|
||||
case res = <-responseChan:
|
||||
receivedResponse = true
|
||||
c.mutex.Lock()
|
||||
delete(c.responses, dataStream.StreamID())
|
||||
c.mutex.Unlock()
|
||||
case err := <-resc:
|
||||
bodySent = true
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// error code 6 signals that stream was canceled
|
||||
dataStream.CancelRead(6)
|
||||
dataStream.CancelWrite(6)
|
||||
c.mutex.Lock()
|
||||
delete(c.responses, dataStream.StreamID())
|
||||
c.mutex.Unlock()
|
||||
return nil, ctx.Err()
|
||||
case <-c.headerErrored:
|
||||
// an error occurred on the header stream
|
||||
_ = c.closeWithError(c.headerErr)
|
||||
return nil, c.headerErr
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: correctly set this variable
|
||||
var streamEnded bool
|
||||
isHead := (req.Method == "HEAD")
|
||||
|
||||
res = setLength(res, isHead, streamEnded)
|
||||
|
||||
if streamEnded || isHead {
|
||||
res.Body = noBody
|
||||
} else {
|
||||
res.Body = dataStream
|
||||
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
||||
res.Header.Del("Content-Encoding")
|
||||
res.Header.Del("Content-Length")
|
||||
res.ContentLength = -1
|
||||
res.Body = &gzipReader{body: res.Body}
|
||||
res.Uncompressed = true
|
||||
}
|
||||
}
|
||||
|
||||
res.Request = req
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
|
||||
defer func() {
|
||||
cerr := body.Close()
|
||||
if err == nil {
|
||||
// TODO: what to do with dataStream here? Maybe reset it?
|
||||
err = cerr
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = io.Copy(dataStream, body)
|
||||
if err != nil {
|
||||
// TODO: what to do with dataStream here? Maybe reset it?
|
||||
return err
|
||||
}
|
||||
return dataStream.Close()
|
||||
}
|
||||
|
||||
func (c *client) closeWithError(e error) error {
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
|
||||
}
|
||||
|
||||
// Close closes the client
|
||||
func (c *client) Close() error {
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.Close()
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||
// and returns a host:port. The port 443 is added if needed.
|
||||
func authorityAddr(scheme string, authority string) (addr string) {
|
||||
host, port, err := net.SplitHostPort(authority)
|
||||
if err != nil { // authority didn't have a port
|
||||
port = "443"
|
||||
if scheme == "http" {
|
||||
port = "80"
|
||||
}
|
||||
host = authority
|
||||
}
|
||||
if a, err := idna.ToASCII(host); err == nil {
|
||||
host = a
|
||||
}
|
||||
// IPv6 address literal, without a port:
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
return host + ":" + port
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
// gzipReader wraps a response body so it can lazily
|
||||
// call gzip.NewReader on the first call to Read
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
)
|
||||
|
||||
// call gzip.NewReader on the first call to Read
|
||||
type gzipReader struct {
|
||||
body io.ReadCloser // underlying Response.Body
|
||||
zr *gzip.Reader // lazily-initialized gzip reader
|
||||
zerr error // sticky error
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Read(p []byte) (n int, err error) {
|
||||
if gz.zerr != nil {
|
||||
return 0, gz.zerr
|
||||
}
|
||||
if gz.zr == nil {
|
||||
gz.zr, err = gzip.NewReader(gz.body)
|
||||
if err != nil {
|
||||
gz.zerr = err
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return gz.zr.Read(p)
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Close() error {
|
||||
return gz.body.Close()
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
|
||||
var path, authority, method, contentLengthStr string
|
||||
httpHeaders := http.Header{}
|
||||
|
||||
for _, h := range headers {
|
||||
switch h.Name {
|
||||
case ":path":
|
||||
path = h.Value
|
||||
case ":method":
|
||||
method = h.Value
|
||||
case ":authority":
|
||||
authority = h.Value
|
||||
case "content-length":
|
||||
contentLengthStr = h.Value
|
||||
default:
|
||||
if !h.IsPseudo() {
|
||||
httpHeaders.Add(h.Name, h.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
|
||||
if len(httpHeaders["Cookie"]) > 0 {
|
||||
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
|
||||
}
|
||||
|
||||
if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
|
||||
return nil, errors.New(":path, :authority and :method must not be empty")
|
||||
}
|
||||
|
||||
u, err := url.Parse(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var contentLength int64
|
||||
if len(contentLengthStr) > 0 {
|
||||
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Method: method,
|
||||
URL: u,
|
||||
Proto: "HTTP/2.0",
|
||||
ProtoMajor: 2,
|
||||
ProtoMinor: 0,
|
||||
Header: httpHeaders,
|
||||
Body: nil,
|
||||
ContentLength: contentLength,
|
||||
Host: authority,
|
||||
RequestURI: path,
|
||||
TLS: &tls.ConnectionState{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func hostnameFromRequest(req *http.Request) string {
|
||||
if req.URL != nil {
|
||||
return req.URL.Host
|
||||
}
|
||||
return ""
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
type requestBody struct {
|
||||
requestRead bool
|
||||
dataStream quic.Stream
|
||||
}
|
||||
|
||||
// make sure the requestBody can be used as a http.Request.Body
|
||||
var _ io.ReadCloser = &requestBody{}
|
||||
|
||||
func newRequestBody(stream quic.Stream) *requestBody {
|
||||
return &requestBody{dataStream: stream}
|
||||
}
|
||||
|
||||
func (b *requestBody) Read(p []byte) (int, error) {
|
||||
b.requestRead = true
|
||||
return b.dataStream.Read(p)
|
||||
}
|
||||
|
||||
func (b *requestBody) Close() error {
|
||||
// stream's Close() closes the write side, not the read side
|
||||
return nil
|
||||
}
|
|
@ -1,203 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type requestWriter struct {
|
||||
mutex sync.Mutex
|
||||
headerStream quic.Stream
|
||||
|
||||
henc *hpack.Encoder
|
||||
hbuf bytes.Buffer // HPACK encoder writes into this
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
const defaultUserAgent = "quic-go"
|
||||
|
||||
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
|
||||
rw := &requestWriter{
|
||||
headerStream: headerStream,
|
||||
logger: logger,
|
||||
}
|
||||
rw.henc = hpack.NewEncoder(&rw.hbuf)
|
||||
return rw
|
||||
}
|
||||
|
||||
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
|
||||
// TODO: add support for trailers
|
||||
// TODO: add support for gzip compression
|
||||
// TODO: write continuation frames, if the header frame is too long
|
||||
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
|
||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
||||
return h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: uint32(dataStreamID),
|
||||
EndHeaders: true,
|
||||
EndStream: endStream,
|
||||
BlockFragment: w.hbuf.Bytes(),
|
||||
Priority: http2.PriorityParam{Weight: 0xff},
|
||||
})
|
||||
}
|
||||
|
||||
// the rest of this files is copied from http2.Transport
|
||||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
|
||||
w.hbuf.Reset()
|
||||
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
host, err := httpguts.PunycodeHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var path string
|
||||
if req.Method != "CONNECT" {
|
||||
path = req.URL.RequestURI()
|
||||
if !validPseudoPath(path) {
|
||||
orig := path
|
||||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
||||
if !validPseudoPath(path) {
|
||||
if req.URL.Opaque != "" {
|
||||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid request :path %q", orig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for any invalid headers and return an error before we
|
||||
// potentially pollute our hpack state. (We want to be able to
|
||||
// continue to reuse the hpack encoder for future requests)
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("invalid HTTP header name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 8.1.2.3 Request Pseudo-Header Fields
|
||||
// The :path pseudo-header field includes the path and query parts of the
|
||||
// target URI (the path-absolute production and optionally a '?' character
|
||||
// followed by the query production (see Sections 3.3 and 3.4 of
|
||||
// [RFC3986]).
|
||||
w.writeHeader(":authority", host)
|
||||
w.writeHeader(":method", req.Method)
|
||||
if req.Method != "CONNECT" {
|
||||
w.writeHeader(":path", path)
|
||||
w.writeHeader(":scheme", req.URL.Scheme)
|
||||
}
|
||||
if trailers != "" {
|
||||
w.writeHeader("trailer", trailers)
|
||||
}
|
||||
|
||||
var didUA bool
|
||||
for k, vv := range req.Header {
|
||||
lowKey := strings.ToLower(k)
|
||||
switch lowKey {
|
||||
case "host", "content-length":
|
||||
// Host is :authority, already sent.
|
||||
// Content-Length is automatic, set below.
|
||||
continue
|
||||
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
|
||||
// Per 8.1.2.2 Connection-Specific Header
|
||||
// Fields, don't send connection-specific
|
||||
// fields. We have already checked if any
|
||||
// are error-worthy so just ignore the rest.
|
||||
continue
|
||||
case "user-agent":
|
||||
// Match Go's http1 behavior: at most one
|
||||
// User-Agent. If set to nil or empty string,
|
||||
// then omit it. Otherwise if not mentioned,
|
||||
// include the default (below).
|
||||
didUA = true
|
||||
if len(vv) < 1 {
|
||||
continue
|
||||
}
|
||||
vv = vv[:1]
|
||||
if vv[0] == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, v := range vv {
|
||||
w.writeHeader(lowKey, v)
|
||||
}
|
||||
}
|
||||
if shouldSendReqContentLength(req.Method, contentLength) {
|
||||
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
|
||||
}
|
||||
if addGzipHeader {
|
||||
w.writeHeader("accept-encoding", "gzip")
|
||||
}
|
||||
if !didUA {
|
||||
w.writeHeader("user-agent", defaultUserAgent)
|
||||
}
|
||||
return w.hbuf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (w *requestWriter) writeHeader(name, value string) {
|
||||
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
|
||||
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
|
||||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
||||
// a "content-length" request header. This logic is basically a copy of the net/http
|
||||
// transferWriter.shouldSendContentLength.
|
||||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
||||
// -1 means unknown.
|
||||
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
||||
if contentLength > 0 {
|
||||
return true
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return false
|
||||
}
|
||||
// For zero bodies, whether we send a content-length depends on the method.
|
||||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
||||
switch method {
|
||||
case "POST", "PUT", "PATCH":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validPseudoPath(v string) bool {
|
||||
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
|
||||
}
|
||||
|
||||
// actualContentLength returns a sanitized version of
|
||||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
||||
// means unknown.
|
||||
func actualContentLength(req *http.Request) int64 {
|
||||
if req.Body == nil {
|
||||
return 0
|
||||
}
|
||||
if req.ContentLength != 0 {
|
||||
return req.ContentLength
|
||||
}
|
||||
return -1
|
||||
}
|
|
@ -1,95 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// copied from net/http2/transport.go
|
||||
|
||||
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
|
||||
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
|
||||
|
||||
// from the handleResponse function
|
||||
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
|
||||
if f.Truncated {
|
||||
return nil, errResponseHeaderListSize
|
||||
}
|
||||
|
||||
status := f.PseudoValue("status")
|
||||
if status == "" {
|
||||
return nil, errors.New("missing status pseudo header")
|
||||
}
|
||||
statusCode, err := strconv.Atoi(status)
|
||||
if err != nil {
|
||||
return nil, errors.New("malformed non-numeric status pseudo header")
|
||||
}
|
||||
|
||||
// TODO: handle statusCode == 100
|
||||
|
||||
header := make(http.Header)
|
||||
res := &http.Response{
|
||||
Proto: "HTTP/2.0",
|
||||
ProtoMajor: 2,
|
||||
Header: header,
|
||||
StatusCode: statusCode,
|
||||
Status: status + " " + http.StatusText(statusCode),
|
||||
}
|
||||
for _, hf := range f.RegularFields() {
|
||||
key := http.CanonicalHeaderKey(hf.Name)
|
||||
if key == "Trailer" {
|
||||
t := res.Trailer
|
||||
if t == nil {
|
||||
t = make(http.Header)
|
||||
res.Trailer = t
|
||||
}
|
||||
foreachHeaderElement(hf.Value, func(v string) {
|
||||
t[http.CanonicalHeaderKey(v)] = nil
|
||||
})
|
||||
} else {
|
||||
header[key] = append(header[key], hf.Value)
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// continuation of the handleResponse function
|
||||
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
|
||||
if !streamEnded || isHead {
|
||||
res.ContentLength = -1
|
||||
if clens := res.Header["Content-Length"]; len(clens) == 1 {
|
||||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
|
||||
res.ContentLength = clen64
|
||||
}
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// copied from net/http/server.go
|
||||
|
||||
// foreachHeaderElement splits v according to the "#rule" construction
|
||||
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
|
||||
func foreachHeaderElement(v string, fn func(string)) {
|
||||
v = textproto.TrimString(v)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if !strings.Contains(v, ",") {
|
||||
fn(v)
|
||||
return
|
||||
}
|
||||
for _, f := range strings.Split(v, ",") {
|
||||
if f = textproto.TrimString(f); f != "" {
|
||||
fn(f)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,114 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type responseWriter struct {
|
||||
dataStreamID protocol.StreamID
|
||||
dataStream quic.Stream
|
||||
|
||||
headerStream quic.Stream
|
||||
headerStreamMutex *sync.Mutex
|
||||
|
||||
header http.Header
|
||||
status int // status code passed to WriteHeader
|
||||
headerWritten bool
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newResponseWriter(
|
||||
headerStream quic.Stream,
|
||||
headerStreamMutex *sync.Mutex,
|
||||
dataStream quic.Stream,
|
||||
dataStreamID protocol.StreamID,
|
||||
logger utils.Logger,
|
||||
) *responseWriter {
|
||||
return &responseWriter{
|
||||
header: http.Header{},
|
||||
headerStream: headerStream,
|
||||
headerStreamMutex: headerStreamMutex,
|
||||
dataStream: dataStream,
|
||||
dataStreamID: dataStreamID,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(status int) {
|
||||
if w.headerWritten {
|
||||
return
|
||||
}
|
||||
w.headerWritten = true
|
||||
w.status = status
|
||||
|
||||
var headers bytes.Buffer
|
||||
enc := hpack.NewEncoder(&headers)
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
||||
|
||||
for k, v := range w.header {
|
||||
for index := range v {
|
||||
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Infof("Responding with %d", status)
|
||||
w.headerStreamMutex.Lock()
|
||||
defer w.headerStreamMutex.Unlock()
|
||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
||||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: uint32(w.dataStreamID),
|
||||
EndHeaders: true,
|
||||
BlockFragment: headers.Bytes(),
|
||||
})
|
||||
if err != nil {
|
||||
w.logger.Errorf("could not write h2 header: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
if !bodyAllowedForStatus(w.status) {
|
||||
return 0, http.ErrBodyNotAllowed
|
||||
}
|
||||
return w.dataStream.Write(p)
|
||||
}
|
||||
|
||||
func (w *responseWriter) Flush() {}
|
||||
|
||||
// This is a NOP. Use http.Request.Context
|
||||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
||||
|
||||
// test that we implement http.Flusher
|
||||
var _ http.Flusher = &responseWriter{}
|
||||
|
||||
// copied from http2/http2.go
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == 204:
|
||||
return false
|
||||
case status == 304:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
9
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go
generated
vendored
9
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go
generated
vendored
|
@ -1,9 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import "net/http"
|
||||
|
||||
// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11.
|
||||
// By defining it in a separate file, we can exclude this file from staticcheck.
|
||||
|
||||
// test that we implement http.CloseNotifier
|
||||
var _ http.CloseNotifier = &responseWriter{}
|
|
@ -1,179 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
type roundTripCloser interface {
|
||||
http.RoundTripper
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// RoundTripper implements the http.RoundTripper interface
|
||||
type RoundTripper struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
// DisableCompression, if true, prevents the Transport from
|
||||
// requesting compression with an "Accept-Encoding: gzip"
|
||||
// request header when the Request contains no existing
|
||||
// Accept-Encoding value. If the Transport requests gzip on
|
||||
// its own and gets a gzipped response, it's transparently
|
||||
// decoded in the Response.Body. However, if the user
|
||||
// explicitly requested gzip it is not automatically
|
||||
// uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// QuicConfig is the quic.Config used for dialing new connections.
|
||||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, quic.DialAddr will be used.
|
||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
clients map[string]roundTripCloser
|
||||
}
|
||||
|
||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
type RoundTripOpt struct {
|
||||
// OnlyCachedConn controls whether the RoundTripper may
|
||||
// create a new QUIC connection. If set true and
|
||||
// no cached connection is available, RoundTrip
|
||||
// will return ErrNoCachedConn.
|
||||
OnlyCachedConn bool
|
||||
}
|
||||
|
||||
var _ roundTripCloser = &RoundTripper{}
|
||||
|
||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
||||
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
|
||||
|
||||
// RoundTripOpt is like RoundTrip, but takes options.
|
||||
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if req.URL == nil {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("quic: nil Request.URL")
|
||||
}
|
||||
if req.URL.Host == "" {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("quic: no Host in request URL")
|
||||
}
|
||||
if req.Header == nil {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("quic: nil Request.Header")
|
||||
}
|
||||
|
||||
if req.URL.Scheme == "https" {
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
closeRequestBody(req)
|
||||
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
|
||||
}
|
||||
|
||||
if req.Method != "" && !validMethod(req.Method) {
|
||||
closeRequestBody(req)
|
||||
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
|
||||
}
|
||||
|
||||
hostname := authorityAddr("https", hostnameFromRequest(req))
|
||||
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cl.RoundTrip(req)
|
||||
}
|
||||
|
||||
// RoundTrip does a round trip.
|
||||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if r.clients == nil {
|
||||
r.clients = make(map[string]roundTripCloser)
|
||||
}
|
||||
|
||||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
if onlyCached {
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
client = newClient(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
||||
r.QuicConfig,
|
||||
r.Dial,
|
||||
)
|
||||
r.clients[hostname] = client
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this RoundTripper has used
|
||||
func (r *RoundTripper) Close() error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for _, client := range r.clients {
|
||||
if err := client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.clients = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeRequestBody(req *http.Request) {
|
||||
if req.Body != nil {
|
||||
req.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
/*
|
||||
Method = "OPTIONS" ; Section 9.2
|
||||
| "GET" ; Section 9.3
|
||||
| "HEAD" ; Section 9.4
|
||||
| "POST" ; Section 9.5
|
||||
| "PUT" ; Section 9.6
|
||||
| "DELETE" ; Section 9.7
|
||||
| "TRACE" ; Section 9.8
|
||||
| "CONNECT" ; Section 9.9
|
||||
| extension-method
|
||||
extension-method = token
|
||||
token = 1*<any CHAR except CTLs or separators>
|
||||
*/
|
||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||
}
|
||||
|
||||
// copied from net/http/http.go
|
||||
func isNotToken(r rune) bool {
|
||||
return !httpguts.IsTokenRune(r)
|
||||
}
|
|
@ -1,402 +0,0 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type streamCreator interface {
|
||||
quic.Session
|
||||
GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
|
||||
}
|
||||
|
||||
type remoteCloser interface {
|
||||
CloseRemote(protocol.ByteCount)
|
||||
}
|
||||
|
||||
// allows mocking of quic.Listen and quic.ListenAddr
|
||||
var (
|
||||
quicListen = quic.Listen
|
||||
quicListenAddr = quic.ListenAddr
|
||||
)
|
||||
|
||||
// Server is a HTTP2 server listening for QUIC connections.
|
||||
type Server struct {
|
||||
*http.Server
|
||||
|
||||
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
|
||||
// If nil, it uses reasonable default values.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Private flag for demo, do not use
|
||||
CloseAfterFirstRequest bool
|
||||
|
||||
port uint32 // used atomically
|
||||
|
||||
listenerMutex sync.Mutex
|
||||
listener quic.Listener
|
||||
closed bool
|
||||
|
||||
supportedVersionsAsString string
|
||||
|
||||
logger utils.Logger // will be set by Server.serveImpl()
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
return s.serveImpl(s.TLSConfig, nil)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
return s.serveImpl(config, nil)
|
||||
}
|
||||
|
||||
// Serve an existing UDP connection.
|
||||
func (s *Server) Serve(conn net.PacketConn) error {
|
||||
return s.serveImpl(s.TLSConfig, conn)
|
||||
}
|
||||
|
||||
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
||||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
||||
s.listenerMutex.Lock()
|
||||
if s.closed {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("Server is already closed")
|
||||
}
|
||||
if s.listener != nil {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("ListenAndServe may only be called once")
|
||||
}
|
||||
|
||||
var ln quic.Listener
|
||||
var err error
|
||||
if conn == nil {
|
||||
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
|
||||
} else {
|
||||
ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
|
||||
}
|
||||
if err != nil {
|
||||
s.listenerMutex.Unlock()
|
||||
return err
|
||||
}
|
||||
s.listener = ln
|
||||
s.listenerMutex.Unlock()
|
||||
|
||||
for {
|
||||
sess, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handleHeaderStream(sess.(streamCreator))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleHeaderStream(session streamCreator) {
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
|
||||
return
|
||||
}
|
||||
|
||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
||||
h2framer := http2.NewFramer(nil, stream)
|
||||
|
||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
||||
for {
|
||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
||||
// QuicErrors must originate from stream.Read() returning an error.
|
||||
// In this case, the session has already logged the error, so we don't
|
||||
// need to log it again.
|
||||
errorCode := qerr.InternalError
|
||||
if qerr, ok := err.(*qerr.QuicError); !ok {
|
||||
errorCode = qerr.ErrorCode
|
||||
s.logger.Errorf("error handling h2 request: %s", err.Error())
|
||||
}
|
||||
session.CloseWithError(quic.ErrorCode(errorCode), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
||||
h2frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||
}
|
||||
var h2headersFrame *http2.HeadersFrame
|
||||
switch f := h2frame.(type) {
|
||||
case *http2.PriorityFrame:
|
||||
// ignore PRIORITY frames
|
||||
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
|
||||
return nil
|
||||
case *http2.HeadersFrame:
|
||||
h2headersFrame = f
|
||||
default:
|
||||
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
|
||||
}
|
||||
|
||||
if !h2headersFrame.HeadersEnded() {
|
||||
return errors.New("http2 header continuation not implemented")
|
||||
}
|
||||
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := requestFromHeaders(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
||||
} else {
|
||||
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
||||
}
|
||||
|
||||
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
|
||||
if dataStream == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRequest should be as non-blocking as possible to minimize
|
||||
// head-of-line blocking. Potentially blocking code is run in a separate
|
||||
// goroutine, enabling handleRequest to return before the code is executed.
|
||||
go func() {
|
||||
streamEnded := h2headersFrame.StreamEnded()
|
||||
if streamEnded {
|
||||
dataStream.(remoteCloser).CloseRemote(0)
|
||||
streamEnded = true
|
||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||
}
|
||||
|
||||
req = req.WithContext(dataStream.Context())
|
||||
reqBody := newRequestBody(dataStream)
|
||||
req.Body = reqBody
|
||||
|
||||
req.RemoteAddr = session.RemoteAddr().String()
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
|
||||
|
||||
handler := s.Handler
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
panicked := false
|
||||
func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
// Copied from net/http/server.go
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
handler.ServeHTTP(responseWriter, req)
|
||||
}()
|
||||
if panicked {
|
||||
responseWriter.WriteHeader(500)
|
||||
} else {
|
||||
responseWriter.WriteHeader(200)
|
||||
}
|
||||
if responseWriter.dataStream != nil {
|
||||
if !streamEnded && !reqBody.requestRead {
|
||||
// in gQUIC, the error code doesn't matter, so just use 0 here
|
||||
responseWriter.dataStream.CancelRead(0)
|
||||
}
|
||||
responseWriter.dataStream.Close()
|
||||
}
|
||||
if s.CloseAfterFirstRequest {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
session.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
|
||||
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
func (s *Server) Close() error {
|
||||
s.listenerMutex.Lock()
|
||||
defer s.listenerMutex.Unlock()
|
||||
s.closed = true
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
|
||||
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
func (s *Server) CloseGracefully(timeout time.Duration) error {
|
||||
// TODO: implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
|
||||
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
|
||||
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
|
||||
func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
||||
port := atomic.LoadUint32(&s.port)
|
||||
|
||||
if port == 0 {
|
||||
// Extract port from s.Server.Addr
|
||||
_, portStr, err := net.SplitHostPort(s.Server.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
portInt, err := net.LookupPort("tcp", portStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
port = uint32(portInt)
|
||||
atomic.StoreUint32(&s.port, port)
|
||||
}
|
||||
|
||||
if s.supportedVersionsAsString == "" {
|
||||
var versions []string
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
versions = append(versions, v.ToAltSvc())
|
||||
}
|
||||
s.supportedVersionsAsString = strings.Join(versions, ",")
|
||||
}
|
||||
|
||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
||||
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
|
||||
// used when handler is nil.
|
||||
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
server := &Server{
|
||||
Server: &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
},
|
||||
}
|
||||
return server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the given network address for both, TLS and QUIC
|
||||
// connetions in parallel. It returns if one of the two returns an error.
|
||||
// http.DefaultServeMux is used when handler is nil.
|
||||
// The correct Alt-Svc headers for QUIC are set.
|
||||
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
// Load certs
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
|
||||
// Open the listeners
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tcpConn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tcpConn.Close()
|
||||
|
||||
tlsConn := tls.NewListener(tcpConn, config)
|
||||
defer tlsConn.Close()
|
||||
|
||||
// Start the servers
|
||||
httpServer := &http.Server{
|
||||
Addr: addr,
|
||||
TLSConfig: config,
|
||||
}
|
||||
|
||||
quicServer := &Server{
|
||||
Server: httpServer,
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
quicServer.SetQuicHeaders(w.Header())
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
hErr := make(chan error)
|
||||
qErr := make(chan error)
|
||||
go func() {
|
||||
hErr <- httpServer.Serve(tlsConn)
|
||||
}()
|
||||
go func() {
|
||||
qErr <- quicServer.Serve(udpConn)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-hErr:
|
||||
quicServer.Close()
|
||||
return err
|
||||
case err := <-qErr:
|
||||
// Cannot close the HTTP server or wait for requests to complete properly :/
|
||||
return err
|
||||
}
|
||||
}
|
|
@ -16,19 +16,11 @@ type StreamID = protocol.StreamID
|
|||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
const (
|
||||
// VersionGQUIC39 is gQUIC version 39.
|
||||
VersionGQUIC39 = protocol.Version39
|
||||
// VersionGQUIC43 is gQUIC version 43.
|
||||
VersionGQUIC43 = protocol.Version43
|
||||
// VersionGQUIC44 is gQUIC version 44.
|
||||
VersionGQUIC44 = protocol.Version44
|
||||
// VersionMilestone0_10_0 uses TLS
|
||||
VersionMilestone0_10_0 = protocol.VersionMilestone0_10_0
|
||||
)
|
||||
|
||||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
type Cookie struct {
|
||||
RemoteAddr string
|
||||
SentTime time.Time
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
type ConnectionState = handshake.ConnectionState
|
||||
|
@ -166,11 +158,7 @@ type Config struct {
|
|||
// If not set, it uses all versions available.
|
||||
// Warning: This API should not be considered stable and will change soon.
|
||||
Versions []VersionNumber
|
||||
// Ask the server to omit the connection ID sent in the Public Header.
|
||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||
// Currently only valid for the client.
|
||||
RequestConnectionIDOmission bool
|
||||
// The length of the connection ID in bytes. Only valid for IETF QUIC.
|
||||
// The length of the connection ID in bytes.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If not set, the interpretation depends on where the Config is used:
|
||||
// If used for dialing an address, a 0 byte connection ID will be used.
|
||||
|
@ -200,13 +188,10 @@ type Config struct {
|
|||
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingStreams int
|
||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||
// This value doesn't have any effect in Google QUIC.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingUniStreams int
|
||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||
KeepAlive bool
|
||||
|
|
|
@ -27,11 +27,12 @@ type SentPacketHandler interface {
|
|||
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||
ShouldSendNumPackets() int
|
||||
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() *Packet
|
||||
DequeueProbePacket() (*Packet, error)
|
||||
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
|
||||
|
||||
PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen)
|
||||
PopPacketNumber() protocol.PacketNumber
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm() error
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package quic
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// The packetNumberGenerator generates the packet number for the next packet
|
||||
|
@ -15,13 +16,17 @@ type packetNumberGenerator struct {
|
|||
|
||||
next protocol.PacketNumber
|
||||
nextToSkip protocol.PacketNumber
|
||||
|
||||
history []protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
|
||||
return &packetNumberGenerator{
|
||||
g := &packetNumberGenerator{
|
||||
next: initial,
|
||||
averagePeriod: averagePeriod,
|
||||
}
|
||||
g.generateNewSkip()
|
||||
return g
|
||||
}
|
||||
|
||||
func (p *packetNumberGenerator) Peek() protocol.PacketNumber {
|
||||
|
@ -35,6 +40,10 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
|
|||
p.next++
|
||||
|
||||
if p.next == p.nextToSkip {
|
||||
if len(p.history)+1 > protocol.MaxTrackedSkippedPackets {
|
||||
p.history = p.history[1:]
|
||||
}
|
||||
p.history = append(p.history, p.next)
|
||||
p.next++
|
||||
p.generateNewSkip()
|
||||
}
|
||||
|
@ -42,28 +51,28 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
|
|||
return next
|
||||
}
|
||||
|
||||
func (p *packetNumberGenerator) generateNewSkip() error {
|
||||
num, err := p.getRandomNumber()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *packetNumberGenerator) generateNewSkip() {
|
||||
num := p.getRandomNumber()
|
||||
skip := protocol.PacketNumber(num) * (p.averagePeriod - 1) / (math.MaxUint16 / 2)
|
||||
// make sure that there are never two consecutive packet numbers that are skipped
|
||||
p.nextToSkip = p.next + 2 + skip
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRandomNumber() generates a cryptographically secure random number between 0 and MaxUint16 (= 65535)
|
||||
// The expectation value is 65535/2
|
||||
func (p *packetNumberGenerator) getRandomNumber() (uint16, error) {
|
||||
func (p *packetNumberGenerator) getRandomNumber() uint16 {
|
||||
b := make([]byte, 2)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rand.Read(b) // ignore the error here
|
||||
|
||||
num := uint16(b[0])<<8 + uint16(b[1])
|
||||
return num, nil
|
||||
return num
|
||||
}
|
||||
|
||||
func (p *packetNumberGenerator) Validate(ack *wire.AckFrame) bool {
|
||||
for _, pn := range p.history {
|
||||
if ack.AcksPacket(pn) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -2,9 +2,9 @@ package ackhandler
|
|||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// The receivedPacketHistory stores if a packet number has already been received.
|
||||
|
|
2
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go
generated
vendored
|
@ -16,8 +16,6 @@ func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
|
|||
// IsFrameRetransmittable returns true if the frame should be retransmitted.
|
||||
func IsFrameRetransmittable(f wire.Frame) bool {
|
||||
switch f.(type) {
|
||||
case *wire.StopWaitingFrame:
|
||||
return false
|
||||
case *wire.AckFrame:
|
||||
return false
|
||||
default:
|
||||
|
|
81
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
81
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
|
@ -8,9 +8,9 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -30,12 +30,13 @@ const (
|
|||
)
|
||||
|
||||
type sentPacketHandler struct {
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
packetNumberGenerator *packetNumberGenerator
|
||||
|
||||
lastSentRetransmittablePacketTime time.Time
|
||||
lastSentHandshakePacketTime time.Time
|
||||
|
||||
nextPacketSendTime time.Time
|
||||
skippedPackets []protocol.PacketNumber
|
||||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestReceivedPacketWithAck protocol.PacketNumber
|
||||
|
@ -45,8 +46,7 @@ type sentPacketHandler struct {
|
|||
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||
largestSentBeforeRTO protocol.PacketNumber
|
||||
|
||||
packetHistory *sentPacketHistory
|
||||
stopWaitingManager stopWaitingManager
|
||||
packetHistory *sentPacketHistory
|
||||
|
||||
retransmissionQueue []*Packet
|
||||
|
||||
|
@ -90,12 +90,12 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
|
|||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
packetHistory: newSentPacketHistory(),
|
||||
stopWaitingManager: stopWaitingManager{},
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
|
||||
packetHistory: newSentPacketHistory(),
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,13 +110,13 @@ func (h *sentPacketHandler) SetHandshakeComplete() {
|
|||
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
|
||||
var queue []*Packet
|
||||
for _, packet := range h.retransmissionQueue {
|
||||
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
||||
if packet.EncryptionLevel == protocol.Encryption1RTT {
|
||||
queue = append(queue, packet)
|
||||
}
|
||||
}
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
|
||||
if p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
|
@ -148,10 +148,7 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra
|
|||
|
||||
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
|
||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||
h.skippedPackets = append(h.skippedPackets, p)
|
||||
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
||||
h.skippedPackets = h.skippedPackets[1:]
|
||||
}
|
||||
h.logger.Debugf("Skipping packet number %#x", p)
|
||||
}
|
||||
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
|
@ -166,7 +163,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
|
|||
isRetransmittable := len(packet.Frames) != 0
|
||||
|
||||
if isRetransmittable {
|
||||
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
if packet.EncryptionLevel != protocol.Encryption1RTT {
|
||||
h.lastSentHandshakePacketTime = packet.SendTime
|
||||
}
|
||||
h.lastSentRetransmittablePacketTime = packet.SendTime
|
||||
|
@ -198,7 +195,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
|
||||
|
||||
if h.skippedPacketsAcked(ackFrame) {
|
||||
if !h.packetNumberGenerator.Validate(ackFrame) {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
}
|
||||
|
||||
|
@ -213,9 +210,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
|
||||
priorInFlight := h.bytesInFlight
|
||||
for _, p := range ackedPackets {
|
||||
if encLevel < p.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||
}
|
||||
// TODO(#1534): check the encryption level
|
||||
// if encLevel < p.EncryptionLevel {
|
||||
// return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||
// }
|
||||
|
||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||
|
@ -234,10 +233,6 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
return err
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -519,12 +514,13 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
|
|||
return h.DequeuePacketForRetransmission(), nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
|
||||
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||
pn := h.packetNumberGenerator.Peek()
|
||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||
return h.packetNumberGenerator.Pop()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendMode() SendMode {
|
||||
|
@ -585,7 +581,7 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
|||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
if p.canBeRetransmitted && p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
|
@ -607,7 +603,6 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
|
|||
return err
|
||||
}
|
||||
h.retransmissionQueue = append(h.retransmissionQueue, p)
|
||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -633,26 +628,6 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
|||
}
|
||||
rto = utils.MaxDuration(rto, minRTOTimeout)
|
||||
// Exponential backoff
|
||||
rto = rto << h.rtoCount
|
||||
rto <<= h.rtoCount
|
||||
return utils.MinDuration(rto, maxRTOTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
||||
for _, p := range h.skippedPackets {
|
||||
if ackFrame.AcksPacket(p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
||||
lowestUnacked := h.lowestUnacked()
|
||||
deleteIndex := 0
|
||||
for i, p := range h.skippedPackets {
|
||||
if p < lowestUnacked {
|
||||
deleteIndex = i + 1
|
||||
}
|
||||
}
|
||||
h.skippedPackets = h.skippedPackets[deleteIndex:]
|
||||
}
|
||||
|
|
6
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
6
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
|
@ -35,7 +35,7 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
|
|||
}
|
||||
if p.canBeRetransmitted {
|
||||
h.numOutstandingPackets++
|
||||
if p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
if p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
h.numOutstandingHandshakePackets++
|
||||
}
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
|
|||
if h.numOutstandingPackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
|
||||
h.numOutstandingHandshakePackets--
|
||||
if h.numOutstandingHandshakePackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
|
@ -147,7 +147,7 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
|||
if h.numOutstandingPackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
|
||||
h.numOutstandingHandshakePackets--
|
||||
if h.numOutstandingHandshakePackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
|
|
43
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go
generated
vendored
43
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go
generated
vendored
|
@ -1,43 +0,0 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
|
||||
type stopWaitingManager struct {
|
||||
largestLeastUnackedSent protocol.PacketNumber
|
||||
nextLeastUnacked protocol.PacketNumber
|
||||
|
||||
lastStopWaitingFrame *wire.StopWaitingFrame
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
|
||||
if force {
|
||||
return s.lastStopWaitingFrame
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.largestLeastUnackedSent = s.nextLeastUnacked
|
||||
swf := &wire.StopWaitingFrame{
|
||||
LeastUnacked: s.nextLeastUnacked,
|
||||
}
|
||||
s.lastStopWaitingFrame = swf
|
||||
return swf
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
||||
largestAcked := ack.LargestAcked()
|
||||
if largestAcked >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = largestAcked + 1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
|
||||
if p >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = p + 1
|
||||
}
|
||||
}
|
|
@ -193,7 +193,7 @@ func (c *cubicSender) OnPacketLost(
|
|||
if c.congestionWindow >= 2*c.initialCongestionWindow {
|
||||
c.minSlowStartExitWindow = c.congestionWindow / 2
|
||||
}
|
||||
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
|
||||
c.congestionWindow -= protocol.DefaultTCPMSS
|
||||
} else if c.reno {
|
||||
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
|
||||
} else {
|
||||
|
|
|
@ -1,72 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/aes12"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type aeadAESGCM12 struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
var _ AEAD = &aeadAESGCM12{}
|
||||
|
||||
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
|
||||
//
|
||||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
||||
// tag size, and couples the cipher and aes packages closely.
|
||||
// See https://github.com/lucas-clemente/aes12.
|
||||
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
||||
}
|
||||
encrypterCipher, err := aes12.NewCipher(myKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encrypter, err := aes12.NewGCM(encrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypterCipher, err := aes12.NewCipher(otherKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := aes12.NewGCM(decrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aeadAESGCM12{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Overhead() int {
|
||||
return aead.encrypter.Overhead()
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
compressedCertsCache *lru.Cache
|
||||
)
|
||||
|
||||
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||
// Hash all inputs
|
||||
hasher := fnv.New64a()
|
||||
for _, v := range chain {
|
||||
hasher.Write(v)
|
||||
}
|
||||
hasher.Write(pCommonSetHashes)
|
||||
hasher.Write(pCachedHashes)
|
||||
hash := hasher.Sum64()
|
||||
|
||||
var result []byte
|
||||
|
||||
resultI, isCached := compressedCertsCache.Get(hash)
|
||||
if isCached {
|
||||
result = resultI.([]byte)
|
||||
} else {
|
||||
var err error
|
||||
result, err = compressChain(chain, pCommonSetHashes, pCachedHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
compressedCertsCache.Add(hash, result)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
compressedCertsCache, err = lru.New(protocol.NumCachedCertificates)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error()))
|
||||
}
|
||||
}
|
|
@ -1,113 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A CertChain holds a certificate and a private key
|
||||
type CertChain interface {
|
||||
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
|
||||
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
|
||||
GetLeafCert(sni string) ([]byte, error)
|
||||
}
|
||||
|
||||
// proofSource stores a key and a certificate for the server proof
|
||||
type certChain struct {
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
var _ CertChain = &certChain{}
|
||||
|
||||
var errNoMatchingCertificate = errors.New("no matching certificate found")
|
||||
|
||||
// NewCertChain loads the key and cert from files
|
||||
func NewCertChain(tlsConfig *tls.Config) CertChain {
|
||||
return &certChain{config: tlsConfig}
|
||||
}
|
||||
|
||||
// SignServerProof signs CHLO and server config for use in the server proof
|
||||
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
|
||||
cert, err := c.getCertForSNI(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signServerProof(cert, chlo, serverConfigData)
|
||||
}
|
||||
|
||||
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
|
||||
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||
cert, err := c.getCertForSNI(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes)
|
||||
}
|
||||
|
||||
// GetLeafCert gets the leaf certificate
|
||||
func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
|
||||
cert, err := c.getCertForSNI(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cert.Certificate[0], nil
|
||||
}
|
||||
|
||||
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||
conf := c.config
|
||||
conf, err := maybeGetConfigForClient(conf, sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The rest of this function is mostly copied from crypto/tls.getCertificate
|
||||
|
||||
if conf.GetCertificate != nil {
|
||||
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
||||
if cert != nil || err != nil {
|
||||
return cert, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(conf.Certificates) == 0 {
|
||||
return nil, errNoMatchingCertificate
|
||||
}
|
||||
|
||||
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
|
||||
// There's only one choice, so no point doing any work.
|
||||
return &conf.Certificates[0], nil
|
||||
}
|
||||
|
||||
name := strings.ToLower(sni)
|
||||
for len(name) > 0 && name[len(name)-1] == '.' {
|
||||
name = name[:len(name)-1]
|
||||
}
|
||||
|
||||
if cert, ok := conf.NameToCertificate[name]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// try replacing labels in the name with wildcards until we get a
|
||||
// match.
|
||||
labels := strings.Split(name, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if cert, ok := conf.NameToCertificate[candidate]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing matches, return the first certificate.
|
||||
return &conf.Certificates[0], nil
|
||||
}
|
||||
|
||||
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
|
||||
if c.GetConfigForClient == nil {
|
||||
return c, nil
|
||||
}
|
||||
return c.GetConfigForClient(&tls.ClientHelloInfo{
|
||||
ServerName: sni,
|
||||
})
|
||||
}
|
|
@ -1,272 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type entryType uint8
|
||||
|
||||
const (
|
||||
entryCompressed entryType = 1
|
||||
entryCached entryType = 2
|
||||
entryCommon entryType = 3
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
t entryType
|
||||
h uint64 // set hash
|
||||
i uint32 // index
|
||||
}
|
||||
|
||||
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||
res := &bytes.Buffer{}
|
||||
|
||||
cachedHashes, err := splitHashes(pCachedHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setHashes, err := splitHashes(pCommonSetHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chainHashes := make([]uint64, len(chain))
|
||||
for i := range chain {
|
||||
chainHashes[i] = HashCert(chain[i])
|
||||
}
|
||||
|
||||
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
|
||||
|
||||
totalUncompressedLen := 0
|
||||
for i, e := range entries {
|
||||
res.WriteByte(uint8(e.t))
|
||||
switch e.t {
|
||||
case entryCached:
|
||||
utils.LittleEndian.WriteUint64(res, e.h)
|
||||
case entryCommon:
|
||||
utils.LittleEndian.WriteUint64(res, e.h)
|
||||
utils.LittleEndian.WriteUint32(res, e.i)
|
||||
case entryCompressed:
|
||||
totalUncompressedLen += 4 + len(chain[i])
|
||||
}
|
||||
}
|
||||
res.WriteByte(0) // end of list
|
||||
|
||||
if totalUncompressedLen > 0 {
|
||||
gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
|
||||
}
|
||||
|
||||
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
|
||||
|
||||
for i, e := range entries {
|
||||
if e.t != entryCompressed {
|
||||
continue
|
||||
}
|
||||
lenCert := len(chain[i])
|
||||
gz.Write([]byte{
|
||||
byte(lenCert & 0xff),
|
||||
byte((lenCert >> 8) & 0xff),
|
||||
byte((lenCert >> 16) & 0xff),
|
||||
byte((lenCert >> 24) & 0xff),
|
||||
})
|
||||
gz.Write(chain[i])
|
||||
}
|
||||
|
||||
gz.Close()
|
||||
}
|
||||
|
||||
return res.Bytes(), nil
|
||||
}
|
||||
|
||||
func decompressChain(data []byte) ([][]byte, error) {
|
||||
var chain [][]byte
|
||||
var entries []entry
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
var numCerts int
|
||||
var hasCompressedCerts bool
|
||||
for {
|
||||
entryTypeByte, err := r.ReadByte()
|
||||
if entryTypeByte == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
et := entryType(entryTypeByte)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numCerts++
|
||||
|
||||
switch et {
|
||||
case entryCached:
|
||||
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
|
||||
return nil, errors.New("unexpected cached certificate")
|
||||
case entryCommon:
|
||||
e := entry{t: entryCommon}
|
||||
e.h, err = utils.LittleEndian.ReadUint64(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.i, err = utils.LittleEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certSet, ok := certSets[e.h]
|
||||
if !ok {
|
||||
return nil, errors.New("unknown certSet")
|
||||
}
|
||||
if e.i >= uint32(len(certSet)) {
|
||||
return nil, errors.New("certificate not found in certSet")
|
||||
}
|
||||
entries = append(entries, e)
|
||||
chain = append(chain, certSet[e.i])
|
||||
case entryCompressed:
|
||||
hasCompressedCerts = true
|
||||
entries = append(entries, entry{t: entryCompressed})
|
||||
chain = append(chain, nil)
|
||||
default:
|
||||
return nil, errors.New("unknown entryType")
|
||||
}
|
||||
}
|
||||
|
||||
if numCerts == 0 {
|
||||
return make([][]byte, 0), nil
|
||||
}
|
||||
|
||||
if hasCompressedCerts {
|
||||
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
fmt.Println(4)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zlibDict := buildZlibDictForEntries(entries, chain)
|
||||
gz, err := zlib.NewReaderDict(r, zlibDict)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
var totalLength uint32
|
||||
var certIndex int
|
||||
for totalLength < uncompressedLength {
|
||||
lenBytes := make([]byte, 4)
|
||||
_, err := gz.Read(lenBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certLen := binary.LittleEndian.Uint32(lenBytes)
|
||||
|
||||
cert := make([]byte, certLen)
|
||||
n, err := gz.Read(cert)
|
||||
if uint32(n) != certLen && err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if certIndex >= len(entries) {
|
||||
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
|
||||
}
|
||||
if entries[certIndex].t == entryCompressed {
|
||||
chain[certIndex] = cert
|
||||
certIndex++
|
||||
break
|
||||
}
|
||||
certIndex++
|
||||
}
|
||||
|
||||
totalLength += 4 + certLen
|
||||
}
|
||||
}
|
||||
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
|
||||
res := make([]entry, len(chain))
|
||||
chainLoop:
|
||||
for i := range chain {
|
||||
// Check if hash is in cachedHashes
|
||||
for j := range cachedHashes {
|
||||
if chainHashes[i] == cachedHashes[j] {
|
||||
res[i] = entry{t: entryCached, h: chainHashes[i]}
|
||||
continue chainLoop
|
||||
}
|
||||
}
|
||||
|
||||
// Go through common sets and check if it's in there
|
||||
for _, setHash := range setHashes {
|
||||
set, ok := certSets[setHash]
|
||||
if !ok {
|
||||
// We don't have this set
|
||||
continue
|
||||
}
|
||||
// We have this set, check if chain[i] is in the set
|
||||
pos := set.findCertInSet(chain[i])
|
||||
if pos >= 0 {
|
||||
// Found
|
||||
res[i] = entry{t: entryCommon, h: setHash, i: uint32(pos)}
|
||||
continue chainLoop
|
||||
}
|
||||
}
|
||||
|
||||
res[i] = entry{t: entryCompressed}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte {
|
||||
var dict bytes.Buffer
|
||||
|
||||
// First the cached and common in reverse order
|
||||
for i := len(entries) - 1; i >= 0; i-- {
|
||||
if entries[i].t == entryCompressed {
|
||||
continue
|
||||
}
|
||||
dict.Write(chain[i])
|
||||
}
|
||||
|
||||
dict.Write(certDictZlib)
|
||||
return dict.Bytes()
|
||||
}
|
||||
|
||||
func splitHashes(hashes []byte) ([]uint64, error) {
|
||||
if len(hashes)%8 != 0 {
|
||||
return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes")
|
||||
}
|
||||
n := len(hashes) / 8
|
||||
res := make([]uint64, n)
|
||||
for i := 0; i < n; i++ {
|
||||
res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8])
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func getCommonCertificateHashes() []byte {
|
||||
ccs := make([]byte, 8*len(certSets))
|
||||
i := 0
|
||||
for certSetHash := range certSets {
|
||||
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
|
||||
i++
|
||||
}
|
||||
return ccs
|
||||
}
|
||||
|
||||
// HashCert calculates the FNV1a hash of a certificate
|
||||
func HashCert(cert []byte) uint64 {
|
||||
h := fnv.New64a()
|
||||
h.Write(cert)
|
||||
return h.Sum64()
|
||||
}
|
|
@ -1,128 +0,0 @@
|
|||
package crypto
|
||||
|
||||
var certDictZlib = []byte{
|
||||
0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04,
|
||||
0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03,
|
||||
0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x30,
|
||||
0x5f, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x04, 0x01,
|
||||
0x06, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, 0xfd, 0x6d, 0x01, 0x07,
|
||||
0x17, 0x01, 0x30, 0x33, 0x20, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65,
|
||||
0x64, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x20, 0x53, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x34,
|
||||
0x20, 0x53, 0x53, 0x4c, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31,
|
||||
0x32, 0x20, 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x53, 0x65, 0x72,
|
||||
0x76, 0x65, 0x72, 0x20, 0x43, 0x41, 0x30, 0x2d, 0x61, 0x69, 0x61, 0x2e,
|
||||
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
|
||||
0x2f, 0x45, 0x2d, 0x63, 0x72, 0x6c, 0x2e, 0x76, 0x65, 0x72, 0x69, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x45, 0x2e, 0x63, 0x65,
|
||||
0x72, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01,
|
||||
0x01, 0x05, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x4a, 0x2e, 0x63,
|
||||
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73,
|
||||
0x2f, 0x63, 0x70, 0x73, 0x20, 0x28, 0x63, 0x29, 0x30, 0x30, 0x09, 0x06,
|
||||
0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x30, 0x0d,
|
||||
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
|
||||
0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x7b, 0x30, 0x1d, 0x06, 0x03, 0x55,
|
||||
0x1d, 0x0e, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86,
|
||||
0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01,
|
||||
0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xd2,
|
||||
0x6f, 0x64, 0x6f, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x43, 0x2e,
|
||||
0x63, 0x72, 0x6c, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
|
||||
0x04, 0x14, 0xb4, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69,
|
||||
0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x30, 0x0b, 0x06, 0x03,
|
||||
0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x30, 0x0d, 0x06, 0x09,
|
||||
0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30,
|
||||
0x81, 0xca, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
|
||||
0x02, 0x55, 0x53, 0x31, 0x10, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x08,
|
||||
0x13, 0x07, 0x41, 0x72, 0x69, 0x7a, 0x6f, 0x6e, 0x61, 0x31, 0x13, 0x30,
|
||||
0x11, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0a, 0x53, 0x63, 0x6f, 0x74,
|
||||
0x74, 0x73, 0x64, 0x61, 0x6c, 0x65, 0x31, 0x1a, 0x30, 0x18, 0x06, 0x03,
|
||||
0x55, 0x04, 0x0a, 0x13, 0x11, 0x47, 0x6f, 0x44, 0x61, 0x64, 0x64, 0x79,
|
||||
0x2e, 0x63, 0x6f, 0x6d, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, 0x33,
|
||||
0x30, 0x31, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x2a, 0x68, 0x74, 0x74,
|
||||
0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63,
|
||||
0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79,
|
||||
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74,
|
||||
0x6f, 0x72, 0x79, 0x31, 0x30, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x04, 0x03,
|
||||
0x13, 0x27, 0x47, 0x6f, 0x20, 0x44, 0x61, 0x64, 0x64, 0x79, 0x20, 0x53,
|
||||
0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66,
|
||||
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68,
|
||||
0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55,
|
||||
0x04, 0x05, 0x13, 0x08, 0x30, 0x37, 0x39, 0x36, 0x39, 0x32, 0x38, 0x37,
|
||||
0x30, 0x1e, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d,
|
||||
0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x0c,
|
||||
0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x02, 0x30, 0x00,
|
||||
0x30, 0x1d, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff,
|
||||
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55,
|
||||
0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05,
|
||||
0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07,
|
||||
0x03, 0x02, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff,
|
||||
0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x33, 0x06, 0x03, 0x55, 0x1d,
|
||||
0x1f, 0x04, 0x2c, 0x30, 0x2a, 0x30, 0x28, 0xa0, 0x26, 0xa0, 0x24, 0x86,
|
||||
0x22, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e,
|
||||
0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
|
||||
0x67, 0x64, 0x73, 0x31, 0x2d, 0x32, 0x30, 0x2a, 0x30, 0x28, 0x06, 0x08,
|
||||
0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x1c, 0x68, 0x74,
|
||||
0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, 0x65,
|
||||
0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
|
||||
0x70, 0x73, 0x30, 0x34, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x5a, 0x17,
|
||||
0x0d, 0x31, 0x33, 0x30, 0x35, 0x30, 0x39, 0x06, 0x08, 0x2b, 0x06, 0x01,
|
||||
0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x3a,
|
||||
0x2f, 0x2f, 0x73, 0x30, 0x39, 0x30, 0x37, 0x06, 0x08, 0x2b, 0x06, 0x01,
|
||||
0x05, 0x05, 0x07, 0x02, 0x30, 0x44, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04,
|
||||
0x3d, 0x30, 0x3b, 0x30, 0x39, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86,
|
||||
0xf8, 0x45, 0x01, 0x07, 0x17, 0x06, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
|
||||
0x55, 0x04, 0x06, 0x13, 0x02, 0x47, 0x42, 0x31, 0x1b, 0x53, 0x31, 0x17,
|
||||
0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0e, 0x56, 0x65, 0x72,
|
||||
0x69, 0x53, 0x69, 0x67, 0x6e, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31,
|
||||
0x1f, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x16, 0x56, 0x65,
|
||||
0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x54, 0x72, 0x75, 0x73, 0x74,
|
||||
0x20, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x31, 0x3b, 0x30, 0x39,
|
||||
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x32, 0x54, 0x65, 0x72, 0x6d, 0x73,
|
||||
0x20, 0x6f, 0x66, 0x20, 0x75, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20, 0x68,
|
||||
0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76,
|
||||
0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
|
||||
0x72, 0x70, 0x61, 0x20, 0x28, 0x63, 0x29, 0x30, 0x31, 0x10, 0x30, 0x0e,
|
||||
0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x07, 0x53, 0x31, 0x13, 0x30, 0x11,
|
||||
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0a, 0x47, 0x31, 0x13, 0x30, 0x11,
|
||||
0x06, 0x0b, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x3c, 0x02, 0x01,
|
||||
0x03, 0x13, 0x02, 0x55, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04,
|
||||
0x03, 0x14, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13,
|
||||
0x31, 0x1d, 0x30, 0x1b, 0x06, 0x03, 0x55, 0x04, 0x0f, 0x13, 0x14, 0x50,
|
||||
0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e,
|
||||
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x31, 0x12, 0x31, 0x21, 0x30,
|
||||
0x1f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x18, 0x44, 0x6f, 0x6d, 0x61,
|
||||
0x69, 0x6e, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x20, 0x56,
|
||||
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x64, 0x31, 0x14, 0x31, 0x31,
|
||||
0x30, 0x2f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x28, 0x53, 0x65, 0x65,
|
||||
0x20, 0x77, 0x77, 0x77, 0x2e, 0x72, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63,
|
||||
0x75, 0x72, 0x65, 0x2e, 0x67, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53,
|
||||
0x69, 0x67, 0x6e, 0x31, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x41,
|
||||
0x2e, 0x63, 0x72, 0x6c, 0x56, 0x65, 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e,
|
||||
0x20, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x33, 0x20, 0x45, 0x63, 0x72,
|
||||
0x6c, 0x2e, 0x67, 0x65, 0x6f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x2e, 0x63,
|
||||
0x6f, 0x6d, 0x2f, 0x63, 0x72, 0x6c, 0x73, 0x2f, 0x73, 0x64, 0x31, 0x1a,
|
||||
0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x3a,
|
||||
0x2f, 0x2f, 0x45, 0x56, 0x49, 0x6e, 0x74, 0x6c, 0x2d, 0x63, 0x63, 0x72,
|
||||
0x74, 0x2e, 0x67, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x69, 0x63, 0x65, 0x72,
|
||||
0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x31, 0x6f, 0x63, 0x73, 0x70, 0x2e,
|
||||
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
|
||||
0x30, 0x39, 0x72, 0x61, 0x70, 0x69, 0x64, 0x73, 0x73, 0x6c, 0x2e, 0x63,
|
||||
0x6f, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63,
|
||||
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72,
|
||||
0x79, 0x2f, 0x30, 0x81, 0x80, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05,
|
||||
0x07, 0x01, 0x01, 0x04, 0x74, 0x30, 0x72, 0x30, 0x24, 0x06, 0x08, 0x2b,
|
||||
0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x18, 0x68, 0x74, 0x74,
|
||||
0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x2e, 0x67, 0x6f, 0x64,
|
||||
0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x4a, 0x06,
|
||||
0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x3e, 0x68,
|
||||
0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66,
|
||||
0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64,
|
||||
0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73,
|
||||
0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x67, 0x64, 0x5f, 0x69, 0x6e, 0x74,
|
||||
0x65, 0x72, 0x6d, 0x65, 0x64, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x72,
|
||||
0x74, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16,
|
||||
0x80, 0x14, 0xfd, 0xac, 0x61, 0x32, 0x93, 0x6c, 0x45, 0xd6, 0xe2, 0xee,
|
||||
0x85, 0x5f, 0x9a, 0xba, 0xe7, 0x76, 0x99, 0x68, 0xcc, 0xe7, 0x30, 0x27,
|
||||
0x86, 0x29, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x86, 0x30,
|
||||
0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73,
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// CertManager manages the certificates sent by the server
|
||||
type CertManager interface {
|
||||
SetData([]byte) error
|
||||
GetCommonCertificateHashes() []byte
|
||||
GetLeafCert() []byte
|
||||
GetLeafCertHash() (uint64, error)
|
||||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
||||
Verify(hostname string) error
|
||||
GetChain() []*x509.Certificate
|
||||
}
|
||||
|
||||
type certManager struct {
|
||||
chain []*x509.Certificate
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
var _ CertManager = &certManager{}
|
||||
|
||||
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
|
||||
|
||||
// NewCertManager creates a new CertManager
|
||||
func NewCertManager(tlsConfig *tls.Config) CertManager {
|
||||
return &certManager{config: tlsConfig}
|
||||
}
|
||||
|
||||
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
|
||||
func (c *certManager) SetData(data []byte) error {
|
||||
byteChain, err := decompressChain(data)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
|
||||
}
|
||||
|
||||
chain := make([]*x509.Certificate, len(byteChain))
|
||||
for i, data := range byteChain {
|
||||
cert, err := x509.ParseCertificate(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chain[i] = cert
|
||||
}
|
||||
|
||||
c.chain = chain
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *certManager) GetChain() []*x509.Certificate {
|
||||
return c.chain
|
||||
}
|
||||
|
||||
func (c *certManager) GetCommonCertificateHashes() []byte {
|
||||
return getCommonCertificateHashes()
|
||||
}
|
||||
|
||||
// GetLeafCert returns the leaf certificate of the certificate chain
|
||||
// it returns nil if the certificate chain has not yet been set
|
||||
func (c *certManager) GetLeafCert() []byte {
|
||||
if len(c.chain) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.chain[0].Raw
|
||||
}
|
||||
|
||||
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
|
||||
func (c *certManager) GetLeafCertHash() (uint64, error) {
|
||||
leafCert := c.GetLeafCert()
|
||||
if leafCert == nil {
|
||||
return 0, errNoCertificateChain
|
||||
}
|
||||
|
||||
h := fnv.New64a()
|
||||
_, err := h.Write(leafCert)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return h.Sum64(), nil
|
||||
}
|
||||
|
||||
// VerifyServerProof verifies the signature of the server config
|
||||
// it should only be called after the certificate chain has been set, otherwise it returns false
|
||||
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
|
||||
if len(c.chain) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
|
||||
}
|
||||
|
||||
// Verify verifies the certificate chain
|
||||
func (c *certManager) Verify(hostname string) error {
|
||||
if len(c.chain) == 0 {
|
||||
return errNoCertificateChain
|
||||
}
|
||||
|
||||
if c.config != nil && c.config.InsecureSkipVerify {
|
||||
return nil
|
||||
}
|
||||
|
||||
leafCert := c.chain[0]
|
||||
|
||||
var opts x509.VerifyOptions
|
||||
if c.config != nil {
|
||||
opts.Roots = c.config.RootCAs
|
||||
if c.config.Time == nil {
|
||||
opts.CurrentTime = time.Now()
|
||||
} else {
|
||||
opts.CurrentTime = c.config.Time()
|
||||
}
|
||||
}
|
||||
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
|
||||
opts.DNSName = hostname
|
||||
|
||||
// the first certificate is the leaf certificate, all others are intermediates
|
||||
if len(c.chain) > 1 {
|
||||
intermediates := x509.NewCertPool()
|
||||
for i := 1; i < len(c.chain); i++ {
|
||||
intermediates.AddCert(c.chain[i])
|
||||
}
|
||||
opts.Intermediates = intermediates
|
||||
}
|
||||
|
||||
_, err := leafCert.Verify(opts)
|
||||
return err
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go-certificates"
|
||||
)
|
||||
|
||||
type certSet [][]byte
|
||||
|
||||
var certSets = map[uint64]certSet{
|
||||
certsets.CertSet2Hash: certsets.CertSet2,
|
||||
certsets.CertSet3Hash: certsets.CertSet3,
|
||||
}
|
||||
|
||||
// findCertInSet searches for the cert in the set. Negative return value means not found.
|
||||
func (s *certSet) findCertInSet(cert []byte) int {
|
||||
for i, c := range *s {
|
||||
if bytes.Equal(c, cert) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
61
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go
generated
vendored
61
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go
generated
vendored
|
@ -1,61 +0,0 @@
|
|||
// +build ignore
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/aead/chacha20"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type aeadChacha20Poly1305 struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
|
||||
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
|
||||
}
|
||||
// copy because ChaCha20Poly1305 expects array pointers
|
||||
var MyKey, OtherKey [32]byte
|
||||
copy(MyKey[:], myKey)
|
||||
copy(OtherKey[:], otherKey)
|
||||
|
||||
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aeadChacha20Poly1305{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
// KeyExchange manages the exchange of keys
|
||||
type curve25519KEX struct {
|
||||
secret [32]byte
|
||||
public [32]byte
|
||||
}
|
||||
|
||||
var _ KeyExchange = &curve25519KEX{}
|
||||
|
||||
// NewCurve25519KEX creates a new KeyExchange using Curve25519, see https://cr.yp.to/ecdh.html
|
||||
func NewCurve25519KEX() (KeyExchange, error) {
|
||||
c := &curve25519KEX{}
|
||||
if _, err := rand.Read(c.secret[:]); err != nil {
|
||||
return nil, errors.New("Curve25519: could not create private key")
|
||||
}
|
||||
curve25519.ScalarBaseMult(&c.public, &c.secret)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *curve25519KEX) PublicKey() []byte {
|
||||
return c.public[:]
|
||||
}
|
||||
|
||||
func (c *curve25519KEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
|
||||
if len(otherPublic) != 32 {
|
||||
return nil, errors.New("Curve25519: expected public key of 32 byte")
|
||||
}
|
||||
var res [32]byte
|
||||
var otherPublicArray [32]byte
|
||||
copy(otherPublicArray[:], otherPublic)
|
||||
curve25519.ScalarMult(&res, &c.secret, &otherPublicArray)
|
||||
return res[:], nil
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// copied from https://github.com/cloudflare/tls-tris/blob/master/hkdf.go
|
||||
func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
|
||||
if salt == nil {
|
||||
salt = make([]byte, hash.Size())
|
||||
}
|
||||
if secret == nil {
|
||||
secret = make([]byte, hash.Size())
|
||||
}
|
||||
extractor := hmac.New(hash.New, salt)
|
||||
extractor.Write(secret)
|
||||
return extractor.Sum(nil)
|
||||
}
|
||||
|
||||
// copied from https://github.com/cloudflare/tls-tris/blob/master/hkdf.go
|
||||
func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte {
|
||||
var (
|
||||
expander = hmac.New(hash.New, prk)
|
||||
res = make([]byte, l)
|
||||
counter = byte(1)
|
||||
prev []byte
|
||||
)
|
||||
|
||||
if l > 255*expander.Size() {
|
||||
panic("hkdf: requested too much output")
|
||||
}
|
||||
|
||||
p := res
|
||||
for len(p) > 0 {
|
||||
expander.Reset()
|
||||
expander.Write(prev)
|
||||
expander.Write(info)
|
||||
expander.Write([]byte{counter})
|
||||
prev = expander.Sum(prev[:0])
|
||||
counter++
|
||||
n := copy(p, prev)
|
||||
p = p[n:]
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// hkdfExpandLabel HKDF expands a label
|
||||
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, length int) []byte {
|
||||
const prefix = "quic "
|
||||
qlabel := make([]byte, 2 /* length */ +1 /* length of label */ +len(prefix)+len(label)+1 /* length of context (empty) */)
|
||||
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
|
||||
qlabel[2] = uint8(len(prefix) + len(label))
|
||||
copy(qlabel[3:], []byte(prefix+label))
|
||||
return hkdfExpand(hash, secret, qlabel, length)
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
clientExporterLabel = "EXPORTER-QUIC client 1rtt"
|
||||
serverExporterLabel = "EXPORTER-QUIC server 1rtt"
|
||||
)
|
||||
|
||||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
||||
type TLSExporter interface {
|
||||
ConnectionState() mint.ConnectionState
|
||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
}
|
||||
|
||||
func qhkdfExpand(secret []byte, label string, length int) []byte {
|
||||
qlabel := make([]byte, 2+1+5+len(label))
|
||||
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
|
||||
qlabel[2] = uint8(5 + len(label))
|
||||
copy(qlabel[3:], []byte("QUIC "+label))
|
||||
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
|
||||
}
|
||||
|
||||
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
||||
var myLabel, otherLabel string
|
||||
if pers == protocol.PerspectiveClient {
|
||||
myLabel = clientExporterLabel
|
||||
otherLabel = serverExporterLabel
|
||||
} else {
|
||||
myLabel = serverExporterLabel
|
||||
otherLabel = clientExporterLabel
|
||||
}
|
||||
myKey, myIV, err := computeKeyAndIV(tls, myLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
|
||||
cs := tls.ConnectionState().CipherSuite
|
||||
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
key = qhkdfExpand(secret, "key", cs.KeyLen)
|
||||
iv = qhkdfExpand(secret, "iv", cs.IvLen)
|
||||
return key, iv, nil
|
||||
}
|
100
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go
generated
vendored
100
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go
generated
vendored
|
@ -1,100 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// DeriveKeysChacha20 derives the client and server keys and creates a matching chacha20poly1305 AEAD instance
|
||||
// func DeriveKeysChacha20(version protocol.VersionNumber, forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) {
|
||||
// otherKey, myKey, otherIV, myIV, err := deriveKeys(version, forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 32)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
|
||||
// }
|
||||
|
||||
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
|
||||
var swap bool
|
||||
if pers == protocol.PerspectiveClient {
|
||||
swap = true
|
||||
}
|
||||
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
// deriveKeys derives the keys and the IVs
|
||||
// swap should be set true if generating the values for the client, and false for the server
|
||||
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) {
|
||||
var info bytes.Buffer
|
||||
if forwardSecure {
|
||||
info.Write([]byte("QUIC forward secure key expansion\x00"))
|
||||
} else {
|
||||
info.Write([]byte("QUIC key expansion\x00"))
|
||||
}
|
||||
info.Write(connID)
|
||||
info.Write(chlo)
|
||||
info.Write(scfg)
|
||||
info.Write(cert)
|
||||
|
||||
r := hkdf.New(sha256.New, sharedSecret, nonces, info.Bytes())
|
||||
|
||||
s := make([]byte, 2*keyLen+2*4)
|
||||
if _, err := io.ReadFull(r, s); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
key1 := s[:keyLen]
|
||||
key2 := s[keyLen : 2*keyLen]
|
||||
iv1 := s[2*keyLen : 2*keyLen+4]
|
||||
iv2 := s[2*keyLen+4:]
|
||||
|
||||
var otherKey, myKey []byte
|
||||
var otherIV, myIV []byte
|
||||
|
||||
if !forwardSecure {
|
||||
if err := diversify(key2, iv2, divNonce); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if swap {
|
||||
otherKey = key2
|
||||
myKey = key1
|
||||
otherIV = iv2
|
||||
myIV = iv1
|
||||
} else {
|
||||
otherKey = key1
|
||||
myKey = key2
|
||||
otherIV = iv1
|
||||
myIV = iv2
|
||||
}
|
||||
|
||||
return otherKey, myKey, otherIV, myIV, nil
|
||||
}
|
||||
|
||||
func diversify(key, iv, divNonce []byte) error {
|
||||
secret := make([]byte, len(key)+len(iv))
|
||||
copy(secret, key)
|
||||
copy(secret[len(key):], iv)
|
||||
|
||||
r := hkdf.New(sha256.New, secret, divNonce, []byte("QUIC key diversification"))
|
||||
|
||||
if _, err := io.ReadFull(r, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.ReadFull(r, iv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
package crypto
|
||||
|
||||
// KeyExchange manages the exchange of keys
|
||||
type KeyExchange interface {
|
||||
PublicKey() []byte
|
||||
CalculateSharedKey(otherPublic []byte) ([]byte, error)
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
// NewNullAEAD creates a NullAEAD
|
||||
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
|
||||
if v.UsesTLS() {
|
||||
return newNullAEADAESGCM(connID, p)
|
||||
}
|
||||
return &nullAEADFNV128a{perspective: p}, nil
|
||||
}
|
|
@ -3,13 +3,13 @@ package crypto
|
|||
import (
|
||||
"crypto"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
|
||||
|
||||
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||
// NewNullAEAD creates a NullAEAD
|
||||
func NewNullAEAD(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||
clientSecret, serverSecret := computeSecrets(connectionID)
|
||||
|
||||
var mySecret, otherSecret []byte
|
||||
|
@ -28,14 +28,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
|
|||
}
|
||||
|
||||
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
|
||||
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
|
||||
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
|
||||
initialSecret := hkdfExtract(crypto.SHA256, connID, quicVersion1Salt)
|
||||
clientSecret = HkdfExpandLabel(crypto.SHA256, initialSecret, "client in", crypto.SHA256.Size())
|
||||
serverSecret = HkdfExpandLabel(crypto.SHA256, initialSecret, "server in", crypto.SHA256.Size())
|
||||
return
|
||||
}
|
||||
|
||||
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
||||
key = qhkdfExpand(secret, "key", 16)
|
||||
iv = qhkdfExpand(secret, "iv", 12)
|
||||
key = HkdfExpandLabel(crypto.SHA256, secret, "key", 16)
|
||||
iv = HkdfExpandLabel(crypto.SHA256, secret, "iv", 12)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// nullAEAD handles not-yet encrypted packets
|
||||
type nullAEADFNV128a struct {
|
||||
perspective protocol.Perspective
|
||||
}
|
||||
|
||||
var _ AEAD = &nullAEADFNV128a{}
|
||||
|
||||
// Open and verify the ciphertext
|
||||
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
if len(src) < 12 {
|
||||
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
||||
}
|
||||
|
||||
hash := fnv.New128a()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src[12:])
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
hash.Write([]byte("Client"))
|
||||
} else {
|
||||
hash.Write([]byte("Server"))
|
||||
}
|
||||
sum := make([]byte, 0, 16)
|
||||
sum = hash.Sum(sum)
|
||||
// The tag is written in little endian, so we need to reverse the slice.
|
||||
reverse(sum)
|
||||
|
||||
if !bytes.Equal(sum[:12], src[:12]) {
|
||||
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
|
||||
}
|
||||
return src[12:], nil
|
||||
}
|
||||
|
||||
// Seal writes hash and ciphertext to the buffer
|
||||
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
if cap(dst) < 12+len(src) {
|
||||
dst = make([]byte, 12+len(src))
|
||||
} else {
|
||||
dst = dst[:12+len(src)]
|
||||
}
|
||||
|
||||
hash := fnv.New128a()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src)
|
||||
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
hash.Write([]byte("Server"))
|
||||
} else {
|
||||
hash.Write([]byte("Client"))
|
||||
}
|
||||
sum := make([]byte, 0, 16)
|
||||
sum = hash.Sum(sum)
|
||||
// The tag is written in little endian, so we need to reverse the slice.
|
||||
reverse(sum)
|
||||
|
||||
copy(dst[12:], src)
|
||||
copy(dst, sum[:12])
|
||||
return dst
|
||||
}
|
||||
|
||||
func (n *nullAEADFNV128a) Overhead() int {
|
||||
return 12
|
||||
}
|
||||
|
||||
func reverse(a []byte) {
|
||||
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
|
||||
a[left], a[right] = a[right], a[left]
|
||||
}
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
type ecdsaSignature struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
|
||||
// signServerProof signs CHLO and server config for use in the server proof
|
||||
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) {
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
|
||||
chloHash := sha256.Sum256(chlo)
|
||||
hash.Write([]byte{32, 0, 0, 0})
|
||||
hash.Write(chloHash[:])
|
||||
hash.Write(serverConfigData)
|
||||
|
||||
key, ok := cert.PrivateKey.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
|
||||
}
|
||||
|
||||
opts := crypto.SignerOpts(crypto.SHA256)
|
||||
|
||||
if _, ok = key.(*rsa.PrivateKey); ok {
|
||||
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
|
||||
}
|
||||
|
||||
return key.Sign(rand.Reader, hash.Sum(nil), opts)
|
||||
}
|
||||
|
||||
// verifyServerProof verifies the server proof signature
|
||||
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool {
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
|
||||
chloHash := sha256.Sum256(chlo)
|
||||
hash.Write([]byte{32, 0, 0, 0})
|
||||
hash.Write(chloHash[:])
|
||||
hash.Write(serverConfigData)
|
||||
|
||||
// RSA
|
||||
if cert.PublicKeyAlgorithm == x509.RSA {
|
||||
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
|
||||
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ECDSA
|
||||
signature := &ecdsaSignature{}
|
||||
rest, err := asn1.Unmarshal(proof, signature)
|
||||
if err != nil || len(rest) != 0 {
|
||||
return false
|
||||
}
|
||||
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S)
|
||||
}
|
|
@ -5,8 +5,8 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type connectionFlowController struct {
|
||||
|
|
|
@ -19,7 +19,7 @@ type StreamFlowController interface {
|
|||
flowController
|
||||
// for receiving
|
||||
// UpdateHighestReceived should be called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
||||
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||
}
|
||||
|
||||
|
|
40
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
40
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
|
@ -5,8 +5,8 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type streamFlowController struct {
|
||||
|
@ -16,8 +16,7 @@ type streamFlowController struct {
|
|||
|
||||
queueWindowUpdate func()
|
||||
|
||||
connection connectionFlowControllerI
|
||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||
connection connectionFlowControllerI
|
||||
|
||||
receivedFinalOffset bool
|
||||
}
|
||||
|
@ -27,7 +26,6 @@ var _ StreamFlowController = &streamFlowController{}
|
|||
// NewStreamFlowController gets a new flow controller for a stream
|
||||
func NewStreamFlowController(
|
||||
streamID protocol.StreamID,
|
||||
contributesToConnection bool,
|
||||
cfc ConnectionFlowController,
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
|
@ -37,10 +35,9 @@ func NewStreamFlowController(
|
|||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
contributesToConnection: contributesToConnection,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
streamID: streamID,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
|
@ -87,32 +84,21 @@ func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCou
|
|||
if c.checkFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
|
||||
}
|
||||
if c.contributesToConnection {
|
||||
return c.connection.IncrementHighestReceived(increment)
|
||||
}
|
||||
return nil
|
||||
return c.connection.IncrementHighestReceived(increment)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesRead(n)
|
||||
if c.contributesToConnection {
|
||||
c.connection.AddBytesRead(n)
|
||||
}
|
||||
c.connection.AddBytesRead(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesSent(n)
|
||||
if c.contributesToConnection {
|
||||
c.connection.AddBytesSent(n)
|
||||
}
|
||||
c.connection.AddBytesSent(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
window := c.baseFlowController.sendWindowSize()
|
||||
if c.contributesToConnection {
|
||||
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
||||
}
|
||||
return window
|
||||
return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
|
||||
}
|
||||
|
||||
func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
||||
|
@ -122,9 +108,7 @@ func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
|||
if hasWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
if c.contributesToConnection {
|
||||
c.connection.MaybeQueueWindowUpdate()
|
||||
}
|
||||
c.connection.MaybeQueueWindowUpdate()
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
|
@ -140,9 +124,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
|||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
||||
if c.contributesToConnection {
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
|
|
58
vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
58
vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
|
@ -0,0 +1,58 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type sealer struct {
|
||||
iv []byte
|
||||
aead cipher.AEAD
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var _ Sealer = &sealer{}
|
||||
|
||||
func newSealer(aead cipher.AEAD, iv []byte) Sealer {
|
||||
return &sealer{
|
||||
iv: iv,
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
|
||||
return s.aead.Seal(dst, s.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (s *sealer) Overhead() int {
|
||||
return s.aead.Overhead()
|
||||
}
|
||||
|
||||
type opener struct {
|
||||
iv []byte
|
||||
aead cipher.AEAD
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var _ Opener = &opener{}
|
||||
|
||||
func newOpener(aead cipher.AEAD, iv []byte) Opener {
|
||||
return &opener{
|
||||
iv: iv,
|
||||
aead: aead,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
|
||||
return o.aead.Open(dst, o.nonceBuf, src, ad)
|
||||
}
|
28
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
28
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
|
@ -5,6 +5,8 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -14,14 +16,17 @@ const (
|
|||
|
||||
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
|
||||
type Cookie struct {
|
||||
RemoteAddr string
|
||||
// The time that the STK was issued (resolution 1 second)
|
||||
RemoteAddr string
|
||||
OriginalDestConnectionID protocol.ConnectionID
|
||||
// The time that the Cookie was issued (resolution 1 second)
|
||||
SentTime time.Time
|
||||
}
|
||||
|
||||
// token is the struct that is used for ASN1 serialization and deserialization
|
||||
type token struct {
|
||||
Data []byte
|
||||
RemoteAddr []byte
|
||||
OriginalDestConnectionID []byte
|
||||
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
|
@ -42,10 +47,11 @@ func NewCookieGenerator() (*CookieGenerator, error) {
|
|||
}
|
||||
|
||||
// NewToken generates a new Cookie for a given source address
|
||||
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||
func (g *CookieGenerator) NewToken(raddr net.Addr, origConnID protocol.ConnectionID) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
Data: encodeRemoteAddr(raddr),
|
||||
Timestamp: time.Now().Unix(),
|
||||
RemoteAddr: encodeRemoteAddr(raddr),
|
||||
OriginalDestConnectionID: origConnID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -72,10 +78,14 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
|
|||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||
}
|
||||
return &Cookie{
|
||||
RemoteAddr: decodeRemoteAddr(t.Data),
|
||||
cookie := &Cookie{
|
||||
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
|
||||
SentTime: time.Unix(t.Timestamp, 0),
|
||||
}, nil
|
||||
}
|
||||
if len(t.OriginalDestConnectionID) > 0 {
|
||||
cookie.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
|
||||
}
|
||||
return cookie, nil
|
||||
}
|
||||
|
||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie
|
||||
|
|
515
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
515
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
|
@ -0,0 +1,515 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type messageType uint8
|
||||
|
||||
// TLS handshake message types.
|
||||
const (
|
||||
typeClientHello messageType = 1
|
||||
typeServerHello messageType = 2
|
||||
typeEncryptedExtensions messageType = 8
|
||||
typeCertificate messageType = 11
|
||||
typeCertificateRequest messageType = 13
|
||||
typeCertificateVerify messageType = 15
|
||||
typeFinished messageType = 20
|
||||
)
|
||||
|
||||
func (m messageType) String() string {
|
||||
switch m {
|
||||
case typeClientHello:
|
||||
return "ClientHello"
|
||||
case typeServerHello:
|
||||
return "ServerHello"
|
||||
case typeEncryptedExtensions:
|
||||
return "EncryptedExtensions"
|
||||
case typeCertificate:
|
||||
return "Certificate"
|
||||
case typeCertificateRequest:
|
||||
return "CertificateRequest"
|
||||
case typeCertificateVerify:
|
||||
return "CertificateVerify"
|
||||
case typeFinished:
|
||||
return "Finished"
|
||||
default:
|
||||
return fmt.Sprintf("unknown message type: %d", m)
|
||||
}
|
||||
}
|
||||
|
||||
type cryptoSetup struct {
|
||||
tlsConf *qtls.Config
|
||||
|
||||
messageChan chan []byte
|
||||
|
||||
readEncLevel protocol.EncryptionLevel
|
||||
writeEncLevel protocol.EncryptionLevel
|
||||
|
||||
handleParamsCallback func(*TransportParameters)
|
||||
|
||||
// There are two ways that an error can occur during the handshake:
|
||||
// 1. as a return value from qtls.Handshake()
|
||||
// 2. when new data is passed to the crypto setup via HandleData()
|
||||
// handshakeErrChan is closed when qtls.Handshake() errors
|
||||
handshakeErrChan chan struct{}
|
||||
// HandleData() sends errors on the messageErrChan
|
||||
messageErrChan chan error
|
||||
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
|
||||
handshakeDone chan struct{}
|
||||
// transport parameters are sent on the receivedTransportParams, as soon as they are received
|
||||
receivedTransportParams <-chan TransportParameters
|
||||
// is closed when Close() is called
|
||||
closeChan chan struct{}
|
||||
|
||||
clientHelloWritten bool
|
||||
clientHelloWrittenChan chan struct{}
|
||||
|
||||
initialStream io.Writer
|
||||
initialAEAD crypto.AEAD
|
||||
|
||||
handshakeStream io.Writer
|
||||
handshakeOpener Opener
|
||||
handshakeSealer Sealer
|
||||
|
||||
opener Opener
|
||||
sealer Sealer
|
||||
// TODO: add a 1-RTT stream (used for session tickets)
|
||||
|
||||
receivedWriteKey chan struct{}
|
||||
receivedReadKey chan struct{}
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
perspective protocol.Perspective
|
||||
}
|
||||
|
||||
var _ qtls.RecordLayer = &cryptoSetup{}
|
||||
var _ CryptoSetup = &cryptoSetup{}
|
||||
|
||||
// NewCryptoSetupClient creates a new crypto setup for the client
|
||||
func NewCryptoSetupClient(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
origConnID protocol.ConnectionID,
|
||||
connID protocol.ConnectionID,
|
||||
params *TransportParameters,
|
||||
handleParams func(*TransportParameters),
|
||||
tlsConf *tls.Config,
|
||||
initialVersion protocol.VersionNumber,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
currentVersion protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||
extHandler, receivedTransportParams := newExtensionHandlerClient(
|
||||
params,
|
||||
origConnID,
|
||||
initialVersion,
|
||||
supportedVersions,
|
||||
currentVersion,
|
||||
logger,
|
||||
)
|
||||
return newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
connID,
|
||||
extHandler,
|
||||
receivedTransportParams,
|
||||
handleParams,
|
||||
tlsConf,
|
||||
logger,
|
||||
perspective,
|
||||
)
|
||||
}
|
||||
|
||||
// NewCryptoSetupServer creates a new crypto setup for the server
|
||||
func NewCryptoSetupServer(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
params *TransportParameters,
|
||||
handleParams func(*TransportParameters),
|
||||
tlsConf *tls.Config,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
currentVersion protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
) (CryptoSetup, error) {
|
||||
extHandler, receivedTransportParams := newExtensionHandlerServer(
|
||||
params,
|
||||
supportedVersions,
|
||||
currentVersion,
|
||||
logger,
|
||||
)
|
||||
cs, _, err := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
connID,
|
||||
extHandler,
|
||||
receivedTransportParams,
|
||||
handleParams,
|
||||
tlsConf,
|
||||
logger,
|
||||
perspective,
|
||||
)
|
||||
return cs, err
|
||||
}
|
||||
|
||||
func newCryptoSetup(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
extHandler tlsExtensionHandler,
|
||||
transportParamChan <-chan TransportParameters,
|
||||
handleParams func(*TransportParameters),
|
||||
tlsConf *tls.Config,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||
initialAEAD, err := crypto.NewNullAEAD(connID, perspective)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cs := &cryptoSetup{
|
||||
initialStream: initialStream,
|
||||
initialAEAD: initialAEAD,
|
||||
handshakeStream: handshakeStream,
|
||||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
handleParamsCallback: handleParams,
|
||||
receivedTransportParams: transportParamChan,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
handshakeDone: make(chan struct{}),
|
||||
handshakeErrChan: make(chan struct{}),
|
||||
messageErrChan: make(chan error, 1),
|
||||
clientHelloWrittenChan: make(chan struct{}),
|
||||
messageChan: make(chan []byte, 100),
|
||||
receivedReadKey: make(chan struct{}),
|
||||
receivedWriteKey: make(chan struct{}),
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
qtlsConf := tlsConfigToQtlsConfig(tlsConf)
|
||||
qtlsConf.AlternativeRecordLayer = cs
|
||||
qtlsConf.GetExtensions = extHandler.GetExtensions
|
||||
qtlsConf.ReceivedExtensions = extHandler.ReceivedExtensions
|
||||
cs.tlsConf = qtlsConf
|
||||
return cs, cs.clientHelloWrittenChan, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) RunHandshake() error {
|
||||
var conn *qtls.Conn
|
||||
switch h.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
conn = qtls.Client(nil, h.tlsConf)
|
||||
case protocol.PerspectiveServer:
|
||||
conn = qtls.Server(nil, h.tlsConf)
|
||||
}
|
||||
// Handle errors that might occur when HandleData() is called.
|
||||
handshakeErrChan := make(chan error, 1)
|
||||
handshakeComplete := make(chan struct{})
|
||||
go func() {
|
||||
defer close(h.handshakeDone)
|
||||
if err := conn.Handshake(); err != nil {
|
||||
handshakeErrChan <- err
|
||||
return
|
||||
}
|
||||
close(handshakeComplete)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-h.closeChan:
|
||||
close(h.messageChan)
|
||||
// wait until the Handshake() go routine has returned
|
||||
<-handshakeErrChan
|
||||
return errors.New("Handshake aborted")
|
||||
case <-handshakeComplete: // return when the handshake is done
|
||||
return nil
|
||||
case err := <-handshakeErrChan:
|
||||
// if handleMessageFor{server,client} are waiting for some qtls action, make them return
|
||||
close(h.handshakeErrChan)
|
||||
return err
|
||||
case err := <-h.messageErrChan:
|
||||
// If the handshake errored because of an error that occurred during HandleData(),
|
||||
// that error message will be more useful than the error message generated by Handshake().
|
||||
// Close the message chan that qtls is receiving messages from.
|
||||
// This will make qtls.Handshake() return.
|
||||
// Thereby the go routine running qtls.Handshake() will return.
|
||||
close(h.messageChan)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Close() error {
|
||||
close(h.closeChan)
|
||||
// wait until qtls.Handshake() actually returned
|
||||
<-h.handshakeDone
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMessage handles a TLS handshake message.
|
||||
// It is called by the crypto streams when a new message is available.
|
||||
// It returns if it is done with messages on the same encryption level.
|
||||
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
|
||||
msgType := messageType(data[0])
|
||||
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
|
||||
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
|
||||
h.messageErrChan <- err
|
||||
return false
|
||||
}
|
||||
h.messageChan <- data
|
||||
switch h.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
return h.handleMessageForClient(msgType)
|
||||
case protocol.PerspectiveServer:
|
||||
return h.handleMessageForServer(msgType)
|
||||
default:
|
||||
panic("")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||
var expected protocol.EncryptionLevel
|
||||
switch msgType {
|
||||
case typeClientHello,
|
||||
typeServerHello:
|
||||
expected = protocol.EncryptionInitial
|
||||
case typeEncryptedExtensions,
|
||||
typeCertificate,
|
||||
typeCertificateRequest,
|
||||
typeCertificateVerify,
|
||||
typeFinished:
|
||||
expected = protocol.EncryptionHandshake
|
||||
default:
|
||||
return fmt.Errorf("unexpected handshake message: %d", msgType)
|
||||
}
|
||||
if encLevel != expected {
|
||||
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
|
||||
switch msgType {
|
||||
case typeClientHello:
|
||||
select {
|
||||
case params := <-h.receivedTransportParams:
|
||||
h.handleParamsCallback(¶ms)
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
// get the handshake write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
// get the 1-RTT write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
// get the handshake read key
|
||||
// TODO: check that the initial stream doesn't have any more data
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case typeCertificate, typeCertificateVerify:
|
||||
// nothing to do
|
||||
return false
|
||||
case typeFinished:
|
||||
// get the 1-RTT read key
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
panic("unexpected handshake message")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
|
||||
switch msgType {
|
||||
case typeServerHello:
|
||||
// get the handshake read key
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case typeEncryptedExtensions:
|
||||
select {
|
||||
case params := <-h.receivedTransportParams:
|
||||
h.handleParamsCallback(¶ms)
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
return false
|
||||
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
|
||||
// nothing to do
|
||||
return false
|
||||
case typeFinished:
|
||||
// get the handshake write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
// While the order of these two is not defined by the TLS spec,
|
||||
// we have to do it on the same order as our TLS library does it.
|
||||
// get the handshake write key
|
||||
select {
|
||||
case <-h.receivedWriteKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
// get the 1-RTT read key
|
||||
select {
|
||||
case <-h.receivedReadKey:
|
||||
case <-h.handshakeErrChan:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
panic("unexpected handshake message: ")
|
||||
}
|
||||
}
|
||||
|
||||
// ReadHandshakeMessage is called by TLS.
|
||||
// It blocks until a new handshake message is available.
|
||||
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||
// TODO: add some error handling here (when the session is closed)
|
||||
msg, ok := <-h.messageChan
|
||||
if !ok {
|
||||
return nil, errors.New("error while handling the handshake message")
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
|
||||
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
|
||||
opener := newOpener(suite.AEAD(key, iv), iv)
|
||||
|
||||
switch h.readEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.readEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeOpener = opener
|
||||
h.logger.Debugf("Installed Handshake Read keys")
|
||||
case protocol.EncryptionHandshake:
|
||||
h.readEncLevel = protocol.Encryption1RTT
|
||||
h.opener = opener
|
||||
h.logger.Debugf("Installed 1-RTT Read keys")
|
||||
default:
|
||||
panic("unexpected read encryption level")
|
||||
}
|
||||
h.receivedReadKey <- struct{}{}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
|
||||
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
|
||||
sealer := newSealer(suite.AEAD(key, iv), iv)
|
||||
|
||||
switch h.writeEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.writeEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeSealer = sealer
|
||||
h.logger.Debugf("Installed Handshake Write keys")
|
||||
case protocol.EncryptionHandshake:
|
||||
h.writeEncLevel = protocol.Encryption1RTT
|
||||
h.sealer = sealer
|
||||
h.logger.Debugf("Installed 1-RTT Write keys")
|
||||
default:
|
||||
panic("unexpected write encryption level")
|
||||
}
|
||||
h.receivedWriteKey <- struct{}{}
|
||||
}
|
||||
|
||||
// WriteRecord is called when TLS writes data
|
||||
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
||||
switch h.writeEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
// assume that the first WriteRecord call contains the ClientHello
|
||||
n, err := h.initialStream.Write(p)
|
||||
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
|
||||
h.clientHelloWritten = true
|
||||
close(h.clientHelloWrittenChan)
|
||||
}
|
||||
return n, err
|
||||
case protocol.EncryptionHandshake:
|
||||
return h.handshakeStream.Write(p)
|
||||
default:
|
||||
return 0, fmt.Errorf("unexpected write encryption level: %s", h.writeEncLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
if h.sealer != nil {
|
||||
return protocol.Encryption1RTT, h.sealer
|
||||
}
|
||||
if h.handshakeSealer != nil {
|
||||
return protocol.EncryptionHandshake, h.handshakeSealer
|
||||
}
|
||||
return protocol.EncryptionInitial, h.initialAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
|
||||
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
|
||||
|
||||
switch level {
|
||||
case protocol.EncryptionInitial:
|
||||
return h.initialAEAD, nil
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakeSealer == nil {
|
||||
return nil, errNoSealer
|
||||
}
|
||||
return h.handshakeSealer, nil
|
||||
case protocol.Encryption1RTT:
|
||||
if h.sealer == nil {
|
||||
return nil, errNoSealer
|
||||
}
|
||||
return h.sealer, nil
|
||||
default:
|
||||
return nil, errNoSealer
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
return h.initialAEAD.Open(dst, src, pn, ad)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
if h.handshakeOpener == nil {
|
||||
return nil, errors.New("no handshake opener")
|
||||
}
|
||||
return h.handshakeOpener.Open(dst, src, pn, ad)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
if h.opener == nil {
|
||||
return nil, errors.New("no 1-RTT opener")
|
||||
}
|
||||
return h.opener.Open(dst, src, pn, ad)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||
// TODO: return the connection state
|
||||
return ConnectionState{}
|
||||
}
|
543
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
543
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
|
@ -1,543 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type cryptoSetupClient struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
hostname string
|
||||
connID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
initialVersion protocol.VersionNumber
|
||||
negotiatedVersions []protocol.VersionNumber
|
||||
|
||||
cryptoStream io.ReadWriter
|
||||
|
||||
serverConfig *serverConfigClient
|
||||
|
||||
stk []byte
|
||||
sno []byte
|
||||
nonc []byte
|
||||
proof []byte
|
||||
chloForSignature []byte
|
||||
lastSentCHLO []byte
|
||||
certManager crypto.CertManager
|
||||
|
||||
divNonceChan chan struct{}
|
||||
diversificationNonce []byte
|
||||
|
||||
clientHelloCounter int
|
||||
serverVerified bool // has the certificate chain and the proof already been verified
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
|
||||
receivedSecurePacket bool
|
||||
nullAEAD crypto.AEAD
|
||||
secureAEAD crypto.AEAD
|
||||
forwardSecureAEAD crypto.AEAD
|
||||
|
||||
paramsChan chan<- TransportParameters
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
params *TransportParameters
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupClient{}
|
||||
|
||||
var (
|
||||
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available")
|
||||
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated")
|
||||
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces")
|
||||
)
|
||||
|
||||
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
||||
func NewCryptoSetupClient(
|
||||
cryptoStream io.ReadWriter,
|
||||
hostname string,
|
||||
connID protocol.ConnectionID,
|
||||
version protocol.VersionNumber,
|
||||
tlsConfig *tls.Config,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
handshakeEvent chan<- struct{},
|
||||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
divNonceChan := make(chan struct{})
|
||||
cs := &cryptoSetupClient{
|
||||
cryptoStream: cryptoStream,
|
||||
hostname: hostname,
|
||||
connID: connID,
|
||||
version: version,
|
||||
certManager: crypto.NewCertManager(tlsConfig),
|
||||
params: params,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
initialVersion: initialVersion,
|
||||
// The server might have sent greased versions in the Version Negotiation packet.
|
||||
// We need strip those from the list, since they won't be included in the handshake tag.
|
||||
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
|
||||
divNonceChan: divNonceChan,
|
||||
logger: logger,
|
||||
}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||
messageChan := make(chan HandshakeMessage)
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
message, err := ParseHandshakeMessage(h.cryptoStream)
|
||||
if err != nil {
|
||||
errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error())
|
||||
return
|
||||
}
|
||||
messageChan <- message
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if err := h.maybeUpgradeCrypto(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.mutex.RLock()
|
||||
sendCHLO := h.secureAEAD == nil
|
||||
h.mutex.RUnlock()
|
||||
if sendCHLO {
|
||||
if err := h.sendCHLO(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var message HandshakeMessage
|
||||
select {
|
||||
case <-h.divNonceChan:
|
||||
// there's no message to process, but we should try upgrading the crypto again
|
||||
continue
|
||||
case message = <-messageChan:
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Debugf("Got %s", message)
|
||||
switch message.Tag {
|
||||
case TagREJ:
|
||||
if err := h.handleREJMessage(message.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
case TagSHLO:
|
||||
params, err := h.handleSHLOMessage(message.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// blocks until the session has received the parameters
|
||||
h.paramsChan <- *params
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
default:
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
||||
var err error
|
||||
|
||||
if stk, ok := cryptoData[TagSTK]; ok {
|
||||
h.stk = stk
|
||||
}
|
||||
|
||||
if sno, ok := cryptoData[TagSNO]; ok {
|
||||
h.sno = sno
|
||||
}
|
||||
|
||||
// TODO: what happens if the server sends a different server config in two packets?
|
||||
if scfg, ok := cryptoData[TagSCFG]; ok {
|
||||
h.serverConfig, err = parseServerConfig(scfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.serverConfig.IsExpired() {
|
||||
return qerr.CryptoServerConfigExpired
|
||||
}
|
||||
|
||||
// now that we have a server config, we can use its OBIT value to generate a client nonce
|
||||
if len(h.nonc) == 0 {
|
||||
err = h.generateClientNonce()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if proof, ok := cryptoData[TagPROF]; ok {
|
||||
h.proof = proof
|
||||
h.chloForSignature = h.lastSentCHLO
|
||||
}
|
||||
|
||||
if crt, ok := cryptoData[TagCERT]; ok {
|
||||
err := h.certManager.SetData(crt)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
|
||||
}
|
||||
|
||||
err = h.certManager.Verify(h.hostname)
|
||||
if err != nil {
|
||||
h.logger.Infof("Certificate validation failed: %s", err.Error())
|
||||
return qerr.ProofInvalid
|
||||
}
|
||||
}
|
||||
|
||||
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
|
||||
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
|
||||
if !validProof {
|
||||
h.logger.Infof("Server proof verification failed")
|
||||
return qerr.ProofInvalid
|
||||
}
|
||||
|
||||
h.serverVerified = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if !h.receivedSecurePacket {
|
||||
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
||||
}
|
||||
|
||||
if sno, ok := cryptoData[TagSNO]; ok {
|
||||
h.sno = sno
|
||||
}
|
||||
|
||||
serverPubs, ok := cryptoData[TagPUBS]
|
||||
if !ok {
|
||||
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||
}
|
||||
|
||||
verTag, ok := cryptoData[TagVER]
|
||||
if !ok {
|
||||
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
||||
}
|
||||
if !h.validateVersionList(verTag) {
|
||||
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
||||
}
|
||||
|
||||
nonce := append(h.nonc, h.sno...)
|
||||
|
||||
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leafCert := h.certManager.GetLeafCert()
|
||||
|
||||
h.forwardSecureAEAD, err = h.keyDerivation(
|
||||
true,
|
||||
ephermalSharedSecret,
|
||||
nonce,
|
||||
h.connID,
|
||||
h.lastSentCHLO,
|
||||
h.serverConfig.Get(),
|
||||
leafCert,
|
||||
nil,
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
|
||||
|
||||
params, err := readHelloMap(cryptoData)
|
||||
if err != nil {
|
||||
return nil, qerr.InvalidCryptoMessageParameter
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
|
||||
numNegotiatedVersions := len(h.negotiatedVersions)
|
||||
if numNegotiatedVersions == 0 {
|
||||
return true
|
||||
}
|
||||
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
|
||||
return false
|
||||
}
|
||||
|
||||
b := bytes.NewReader(verTags)
|
||||
for i := 0; i < numNegotiatedVersions; i++ {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil { // should never occur, since the length was already checked
|
||||
return false
|
||||
}
|
||||
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if h.forwardSecureAEAD != nil {
|
||||
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
return data, protocol.EncryptionForwardSecure, nil
|
||||
}
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
|
||||
if h.secureAEAD != nil {
|
||||
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||
h.receivedSecurePacket = true
|
||||
return data, protocol.EncryptionSecure, nil
|
||||
}
|
||||
if h.receivedSecurePacket {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
}
|
||||
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err != nil {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
return res, protocol.EncryptionUnencrypted, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
if h.forwardSecureAEAD != nil {
|
||||
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
|
||||
} else if h.secureAEAD != nil {
|
||||
return protocol.EncryptionSecure, h.secureAEAD
|
||||
} else {
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
switch encLevel {
|
||||
case protocol.EncryptionUnencrypted:
|
||||
return h.nullAEAD, nil
|
||||
case protocol.EncryptionSecure:
|
||||
if h.secureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupClient: no secureAEAD")
|
||||
}
|
||||
return h.secureAEAD, nil
|
||||
case protocol.EncryptionForwardSecure:
|
||||
if h.forwardSecureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
|
||||
}
|
||||
return h.forwardSecureAEAD, nil
|
||||
}
|
||||
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
HandshakeComplete: h.forwardSecureAEAD != nil,
|
||||
PeerCertificates: h.certManager.GetChain(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
|
||||
h.mutex.Lock()
|
||||
if len(h.diversificationNonce) > 0 {
|
||||
defer h.mutex.Unlock()
|
||||
if !bytes.Equal(h.diversificationNonce, divNonce) {
|
||||
return errConflictingDiversificationNonces
|
||||
}
|
||||
return nil
|
||||
}
|
||||
h.diversificationNonce = divNonce
|
||||
h.mutex.Unlock()
|
||||
h.divNonceChan <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sendCHLO() error {
|
||||
h.clientHelloCounter++
|
||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos))
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
tags, err := h.getTags()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.addPadding(tags)
|
||||
message := HandshakeMessage{
|
||||
Tag: TagCHLO,
|
||||
Data: tags,
|
||||
}
|
||||
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
message.Write(b)
|
||||
|
||||
_, err = h.cryptoStream.Write(b.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.lastSentCHLO = b.Bytes()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
||||
tags := h.params.getHelloMap()
|
||||
tags[TagSNI] = []byte(h.hostname)
|
||||
tags[TagPDMD] = []byte("X509")
|
||||
|
||||
ccs := h.certManager.GetCommonCertificateHashes()
|
||||
if len(ccs) > 0 {
|
||||
tags[TagCCS] = ccs
|
||||
}
|
||||
|
||||
versionTag := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
|
||||
tags[TagVER] = versionTag
|
||||
|
||||
if len(h.stk) > 0 {
|
||||
tags[TagSTK] = h.stk
|
||||
}
|
||||
if len(h.sno) > 0 {
|
||||
tags[TagSNO] = h.sno
|
||||
}
|
||||
|
||||
if h.serverConfig != nil {
|
||||
tags[TagSCID] = h.serverConfig.ID
|
||||
|
||||
leafCert := h.certManager.GetLeafCert()
|
||||
if leafCert != nil {
|
||||
certHash, _ := h.certManager.GetLeafCertHash()
|
||||
xlct := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(xlct, certHash)
|
||||
|
||||
tags[TagNONC] = h.nonc
|
||||
tags[TagXLCT] = xlct
|
||||
tags[TagKEXS] = []byte("C255")
|
||||
tags[TagAEAD] = []byte("AESG")
|
||||
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
|
||||
}
|
||||
}
|
||||
|
||||
return tags, nil
|
||||
}
|
||||
|
||||
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
|
||||
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
|
||||
var size int
|
||||
for _, tag := range tags {
|
||||
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
||||
}
|
||||
paddingSize := protocol.MinClientHelloSize - size
|
||||
if paddingSize > 0 {
|
||||
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
||||
if !h.serverVerified {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
leafCert := h.certManager.GetLeafCert()
|
||||
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) {
|
||||
var err error
|
||||
var nonce []byte
|
||||
if h.sno == nil {
|
||||
nonce = h.nonc
|
||||
} else {
|
||||
nonce = append(h.nonc, h.sno...)
|
||||
}
|
||||
|
||||
h.secureAEAD, err = h.keyDerivation(
|
||||
false,
|
||||
h.serverConfig.sharedSecret,
|
||||
nonce,
|
||||
h.connID,
|
||||
h.lastSentCHLO,
|
||||
h.serverConfig.Get(),
|
||||
leafCert,
|
||||
h.diversificationNonce,
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||
h.handshakeEvent <- struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) generateClientNonce() error {
|
||||
if len(h.nonc) > 0 {
|
||||
return errClientNonceAlreadyExists
|
||||
}
|
||||
|
||||
nonc := make([]byte, 32)
|
||||
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix()))
|
||||
|
||||
if len(h.serverConfig.obit) != 8 {
|
||||
return errNoObitForClientNonce
|
||||
}
|
||||
|
||||
copy(nonc[4:12], h.serverConfig.obit)
|
||||
|
||||
_, err := rand.Read(nonc[12:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.nonc = nonc
|
||||
return nil
|
||||
}
|
467
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
467
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
|
@ -1,467 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// QuicCryptoKeyDerivationFunction is used for key derivation
|
||||
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
// KeyExchangeFunction is used to make a new KEX
|
||||
type KeyExchangeFunction func() (crypto.KeyExchange, error)
|
||||
|
||||
// The CryptoSetupServer handles all things crypto for the Session
|
||||
type cryptoSetupServer struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
connID protocol.ConnectionID
|
||||
remoteAddr net.Addr
|
||||
scfg *ServerConfig
|
||||
diversificationNonce []byte
|
||||
|
||||
version protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
|
||||
acceptSTKCallback func(net.Addr, *Cookie) bool
|
||||
|
||||
nullAEAD crypto.AEAD
|
||||
secureAEAD crypto.AEAD
|
||||
forwardSecureAEAD crypto.AEAD
|
||||
receivedForwardSecurePacket bool
|
||||
receivedSecurePacket bool
|
||||
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
|
||||
|
||||
receivedParams bool
|
||||
paramsChan chan<- TransportParameters
|
||||
handshakeEvent chan<- struct{}
|
||||
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
keyExchange KeyExchangeFunction
|
||||
|
||||
cryptoStream io.ReadWriter
|
||||
|
||||
params *TransportParameters
|
||||
|
||||
sni string // need to fill out the ConnectionState
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupServer{}
|
||||
|
||||
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
|
||||
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
|
||||
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
|
||||
|
||||
// NewCryptoSetup creates a new CryptoSetup instance for a server
|
||||
func NewCryptoSetup(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
remoteAddr net.Addr,
|
||||
version protocol.VersionNumber,
|
||||
divNonce []byte,
|
||||
scfg *ServerConfig,
|
||||
params *TransportParameters,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
acceptSTK func(net.Addr, *Cookie) bool,
|
||||
paramsChan chan<- TransportParameters,
|
||||
handshakeEvent chan<- struct{},
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cryptoSetupServer{
|
||||
cryptoStream: cryptoStream,
|
||||
connID: connID,
|
||||
remoteAddr: remoteAddr,
|
||||
version: version,
|
||||
supportedVersions: supportedVersions,
|
||||
diversificationNonce: divNonce,
|
||||
scfg: scfg,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
params: params,
|
||||
acceptSTKCallback: acceptSTK,
|
||||
sentSHLO: make(chan struct{}),
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleCryptoStream reads and writes messages on the crypto stream
|
||||
func (h *cryptoSetupServer) HandleCryptoStream() error {
|
||||
for {
|
||||
var chloData bytes.Buffer
|
||||
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
|
||||
if err != nil {
|
||||
return qerr.HandshakeFailed
|
||||
}
|
||||
if message.Tag != TagCHLO {
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
|
||||
h.logger.Debugf("Got %s", message)
|
||||
done, err := h.handleMessage(chloData.Bytes(), message.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
|
||||
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
|
||||
return false, ErrNSTPExperiment
|
||||
}
|
||||
|
||||
sniSlice, ok := cryptoData[TagSNI]
|
||||
if !ok {
|
||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
||||
}
|
||||
sni := string(sniSlice)
|
||||
if sni == "" {
|
||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
||||
}
|
||||
h.sni = sni
|
||||
|
||||
// prevent version downgrade attacks
|
||||
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
|
||||
verSlice, ok := cryptoData[TagVER]
|
||||
if !ok {
|
||||
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")
|
||||
}
|
||||
if len(verSlice) != 4 {
|
||||
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")
|
||||
}
|
||||
ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice))
|
||||
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
|
||||
if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
|
||||
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
var err error
|
||||
|
||||
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
params, err := readHelloMap(cryptoData)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// blocks until the session has received the parameters
|
||||
if !h.receivedParams {
|
||||
h.receivedParams = true
|
||||
h.paramsChan <- *params
|
||||
}
|
||||
|
||||
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
|
||||
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
|
||||
reply, err = h.handleCHLO(sni, chloData, cryptoData)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if _, err := h.cryptoStream.Write(reply); err != nil {
|
||||
return false, err
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.sentSHLO)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// We have an inchoate or non-matching CHLO, we now send a rejection
|
||||
reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, err = h.cryptoStream.Write(reply)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Open a message
|
||||
func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if h.forwardSecureAEAD != nil {
|
||||
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
||||
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
|
||||
h.receivedForwardSecurePacket = true
|
||||
// wait for the send on the handshakeEvent chan
|
||||
<-h.sentSHLO
|
||||
close(h.handshakeEvent)
|
||||
}
|
||||
return res, protocol.EncryptionForwardSecure, nil
|
||||
}
|
||||
if h.receivedForwardSecurePacket {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
}
|
||||
if h.secureAEAD != nil {
|
||||
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||
h.receivedSecurePacket = true
|
||||
return res, protocol.EncryptionSecure, nil
|
||||
}
|
||||
if h.receivedSecurePacket {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
}
|
||||
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err != nil {
|
||||
return res, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
return res, protocol.EncryptionUnencrypted, err
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
if h.forwardSecureAEAD != nil {
|
||||
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
|
||||
}
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
if h.secureAEAD != nil {
|
||||
return protocol.EncryptionSecure, h.secureAEAD
|
||||
}
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
switch encLevel {
|
||||
case protocol.EncryptionUnencrypted:
|
||||
return h.nullAEAD, nil
|
||||
case protocol.EncryptionSecure:
|
||||
if h.secureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupServer: no secureAEAD")
|
||||
}
|
||||
return h.secureAEAD, nil
|
||||
case protocol.EncryptionForwardSecure:
|
||||
if h.forwardSecureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
|
||||
}
|
||||
return h.forwardSecureAEAD, nil
|
||||
}
|
||||
return nil, errors.New("CryptoSetupServer: no encryption level specified")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
|
||||
if _, ok := cryptoData[TagPUBS]; !ok {
|
||||
return true
|
||||
}
|
||||
scid, ok := cryptoData[TagSCID]
|
||||
if !ok || !bytes.Equal(h.scfg.ID, scid) {
|
||||
return true
|
||||
}
|
||||
xlctTag, ok := cryptoData[TagXLCT]
|
||||
if !ok || len(xlctTag) != 8 {
|
||||
return true
|
||||
}
|
||||
xlct := binary.LittleEndian.Uint64(xlctTag)
|
||||
if crypto.HashCert(cert) != xlct {
|
||||
return true
|
||||
}
|
||||
return !h.acceptSTK(cryptoData[TagSTK])
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
||||
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
h.logger.Debugf("STK invalid: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
return h.acceptSTKCallback(h.remoteAddr, stk)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
replyMap := map[Tag][]byte{
|
||||
TagSCFG: h.scfg.Get(),
|
||||
TagSTK: token,
|
||||
TagSVID: []byte("quic-go"),
|
||||
}
|
||||
|
||||
if h.acceptSTK(cryptoData[TagSTK]) {
|
||||
proof, err := h.scfg.Sign(sni, chlo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commonSetHashes := cryptoData[TagCCS]
|
||||
cachedCertsHashes := cryptoData[TagCCRT]
|
||||
|
||||
certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Token was valid, send more details
|
||||
replyMap[TagPROF] = proof
|
||||
replyMap[TagCERT] = certCompressed
|
||||
}
|
||||
|
||||
message := HandshakeMessage{
|
||||
Tag: TagREJ,
|
||||
Data: replyMap,
|
||||
}
|
||||
|
||||
var serverReply bytes.Buffer
|
||||
message.Write(&serverReply)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
return serverReply.Bytes(), nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||
// We have a CHLO matching our server config, we can continue with the 0-RTT handshake
|
||||
sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverNonce := make([]byte, 32)
|
||||
if _, err = rand.Read(serverNonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientNonce := cryptoData[TagNONC]
|
||||
err = h.validateClientNonce(clientNonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aead := cryptoData[TagAEAD]
|
||||
if !bytes.Equal(aead, []byte("AESG")) {
|
||||
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
|
||||
}
|
||||
|
||||
kexs := cryptoData[TagKEXS]
|
||||
if !bytes.Equal(kexs, []byte("C255")) {
|
||||
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
|
||||
}
|
||||
|
||||
h.secureAEAD, err = h.keyDerivation(
|
||||
false,
|
||||
sharedSecret,
|
||||
clientNonce,
|
||||
h.connID,
|
||||
data,
|
||||
h.scfg.Get(),
|
||||
certUncompressed,
|
||||
h.diversificationNonce,
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||
h.handshakeEvent <- struct{}{}
|
||||
|
||||
// Generate a new curve instance to derive the forward secure key
|
||||
var fsNonce bytes.Buffer
|
||||
fsNonce.Write(clientNonce)
|
||||
fsNonce.Write(serverNonce)
|
||||
ephermalKex, err := h.keyExchange()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.forwardSecureAEAD, err = h.keyDerivation(
|
||||
true,
|
||||
ephermalSharedSecret,
|
||||
fsNonce.Bytes(),
|
||||
h.connID,
|
||||
data,
|
||||
h.scfg.Get(),
|
||||
certUncompressed,
|
||||
nil,
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
|
||||
|
||||
replyMap := h.params.getHelloMap()
|
||||
// add crypto parameters
|
||||
verTag := &bytes.Buffer{}
|
||||
for _, v := range h.supportedVersions {
|
||||
utils.BigEndian.WriteUint32(verTag, uint32(v))
|
||||
}
|
||||
replyMap[TagPUBS] = ephermalKex.PublicKey()
|
||||
replyMap[TagSNO] = serverNonce
|
||||
replyMap[TagVER] = verTag.Bytes()
|
||||
|
||||
// note that the SHLO *has* to fit into one packet
|
||||
message := HandshakeMessage{
|
||||
Tag: TagSHLO,
|
||||
Data: replyMap,
|
||||
}
|
||||
var reply bytes.Buffer
|
||||
message.Write(&reply)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
return reply.Bytes(), nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
ServerName: h.sni,
|
||||
HandshakeComplete: h.receivedForwardSecurePacket,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
||||
if len(nonce) != 32 {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
|
||||
}
|
||||
if !bytes.Equal(nonce[4:12], h.scfg.obit) {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")
|
||||
}
|
||||
return nil
|
||||
}
|
163
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
163
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
|
@ -1,163 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// KeyDerivationFunction is used for key derivation
|
||||
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
type cryptoSetupTLS struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
perspective protocol.Perspective
|
||||
|
||||
keyDerivation KeyDerivationFunction
|
||||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
|
||||
tls mintTLS
|
||||
conn *cryptoStreamConn
|
||||
handshakeEvent chan<- struct{}
|
||||
}
|
||||
|
||||
var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
||||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
func NewCryptoSetupTLSServer(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
config *mint.Config,
|
||||
handshakeEvent chan<- struct{},
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetupTLS, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := newCryptoStreamConn(cryptoStream)
|
||||
tls := mint.Server(conn, config)
|
||||
return &cryptoSetupTLS{
|
||||
tls: tls,
|
||||
conn: conn,
|
||||
nullAEAD: nullAEAD,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
||||
func NewCryptoSetupTLSClient(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
config *mint.Config,
|
||||
handshakeEvent chan<- struct{},
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetupTLS, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := newCryptoStreamConn(cryptoStream)
|
||||
tls := mint.Client(conn, config)
|
||||
return &cryptoSetupTLS{
|
||||
tls: tls,
|
||||
conn: conn,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
handshakeEvent: handshakeEvent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||
for {
|
||||
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
}
|
||||
state := h.tls.ConnectionState().HandshakeState
|
||||
if err := h.conn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
if state == mint.StateClientConnected || state == mint.StateServerConnected {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
aead, err := h.keyDerivation(h.tls, h.perspective)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.mutex.Lock()
|
||||
h.aead = aead
|
||||
h.mutex.Unlock()
|
||||
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if h.aead == nil {
|
||||
return nil, errors.New("no 1-RTT sealer")
|
||||
}
|
||||
return h.aead.Open(dst, src, packetNumber, associatedData)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if h.aead != nil {
|
||||
return protocol.EncryptionForwardSecure, h.aead
|
||||
}
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
||||
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", encLevel.String())
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
switch encLevel {
|
||||
case protocol.EncryptionUnencrypted:
|
||||
return h.nullAEAD, nil
|
||||
case protocol.EncryptionForwardSecure:
|
||||
if h.aead == nil {
|
||||
return nil, errNoSealer
|
||||
}
|
||||
return h.aead, nil
|
||||
default:
|
||||
return nil, errNoSealer
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
mintConnState := h.tls.ConnectionState()
|
||||
return ConnectionState{
|
||||
// TODO: set the ServerName, once mint exports it
|
||||
HandshakeComplete: h.aead != nil,
|
||||
PeerCertificates: mintConnState.PeerCertificates,
|
||||
}
|
||||
}
|
69
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
69
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
|
@ -1,69 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type cryptoStreamConn struct {
|
||||
buffer *bytes.Buffer
|
||||
stream io.ReadWriter
|
||||
}
|
||||
|
||||
var _ net.Conn = &cryptoStreamConn{}
|
||||
|
||||
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
|
||||
return &cryptoStreamConn{
|
||||
stream: stream,
|
||||
buffer: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
|
||||
return c.stream.Read(b)
|
||||
}
|
||||
|
||||
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
|
||||
return c.buffer.Write(p)
|
||||
}
|
||||
|
||||
func (c *cryptoStreamConn) Flush() error {
|
||||
if c.buffer.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := c.stream.Write(c.buffer.Bytes())
|
||||
c.buffer.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
// Close is not implemented
|
||||
func (c *cryptoStreamConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr is not implemented
|
||||
func (c *cryptoStreamConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr is not implemented
|
||||
func (c *cryptoStreamConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline is not implemented
|
||||
func (c *cryptoStreamConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline is not implemented
|
||||
func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline is not implemented
|
||||
func (c *cryptoStreamConn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
kexLifetime = protocol.EphermalKeyLifetime
|
||||
kexCurrent crypto.KeyExchange
|
||||
kexCurrentTime time.Time
|
||||
kexMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// getEphermalKEX returns the currently active KEX, which changes every protocol.EphermalKeyLifetime
|
||||
// See the explanation from the QUIC crypto doc:
|
||||
//
|
||||
// A single connection is the usual scope for forward security, but the security
|
||||
// difference between an ephemeral key used for a single connection, and one
|
||||
// used for all connections for 60 seconds is negligible. Thus we can amortise
|
||||
// the Diffie-Hellman key generation at the server over all the connections in a
|
||||
// small time span.
|
||||
func getEphermalKEX() (crypto.KeyExchange, error) {
|
||||
kexMutex.RLock()
|
||||
res := kexCurrent
|
||||
t := kexCurrentTime
|
||||
kexMutex.RUnlock()
|
||||
if res != nil && time.Since(t) < kexLifetime {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
kexMutex.Lock()
|
||||
defer kexMutex.Unlock()
|
||||
// Check if still unfulfilled
|
||||
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
|
||||
kex, err := crypto.NewCurve25519KEX()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kexCurrent = kex
|
||||
kexCurrentTime = time.Now()
|
||||
return kexCurrent, nil
|
||||
}
|
||||
return kexCurrent, nil
|
||||
}
|
137
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
137
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
|
@ -1,137 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// A HandshakeMessage is a handshake message
|
||||
type HandshakeMessage struct {
|
||||
Tag Tag
|
||||
Data map[Tag][]byte
|
||||
}
|
||||
|
||||
var _ fmt.Stringer = &HandshakeMessage{}
|
||||
|
||||
// ParseHandshakeMessage reads a crypto message
|
||||
func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) {
|
||||
slice4 := make([]byte, 4)
|
||||
|
||||
if _, err := io.ReadFull(r, slice4); err != nil {
|
||||
return HandshakeMessage{}, err
|
||||
}
|
||||
messageTag := Tag(binary.LittleEndian.Uint32(slice4))
|
||||
|
||||
if _, err := io.ReadFull(r, slice4); err != nil {
|
||||
return HandshakeMessage{}, err
|
||||
}
|
||||
nPairs := binary.LittleEndian.Uint32(slice4)
|
||||
|
||||
if nPairs > protocol.CryptoMaxParams {
|
||||
return HandshakeMessage{}, qerr.CryptoTooManyEntries
|
||||
}
|
||||
|
||||
index := make([]byte, nPairs*8)
|
||||
if _, err := io.ReadFull(r, index); err != nil {
|
||||
return HandshakeMessage{}, err
|
||||
}
|
||||
|
||||
resultMap := map[Tag][]byte{}
|
||||
|
||||
var dataStart uint32
|
||||
for indexPos := 0; indexPos < int(nPairs)*8; indexPos += 8 {
|
||||
tag := Tag(binary.LittleEndian.Uint32(index[indexPos : indexPos+4]))
|
||||
dataEnd := binary.LittleEndian.Uint32(index[indexPos+4 : indexPos+8])
|
||||
|
||||
dataLen := dataEnd - dataStart
|
||||
if dataLen > protocol.CryptoParameterMaxLength {
|
||||
return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long")
|
||||
}
|
||||
|
||||
data := make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(r, data); err != nil {
|
||||
return HandshakeMessage{}, err
|
||||
}
|
||||
|
||||
resultMap[tag] = data
|
||||
dataStart = dataEnd
|
||||
}
|
||||
|
||||
return HandshakeMessage{
|
||||
Tag: messageTag,
|
||||
Data: resultMap}, nil
|
||||
}
|
||||
|
||||
// Write writes a crypto message
|
||||
func (h HandshakeMessage) Write(b *bytes.Buffer) {
|
||||
data := h.Data
|
||||
utils.LittleEndian.WriteUint32(b, uint32(h.Tag))
|
||||
utils.LittleEndian.WriteUint16(b, uint16(len(data)))
|
||||
utils.LittleEndian.WriteUint16(b, 0)
|
||||
|
||||
// Save current position in the buffer, so that we can update the index in-place later
|
||||
indexStart := b.Len()
|
||||
|
||||
indexData := make([]byte, 8*len(data))
|
||||
b.Write(indexData) // Will be updated later
|
||||
|
||||
offset := uint32(0)
|
||||
for i, t := range h.getTagsSorted() {
|
||||
v := data[t]
|
||||
b.Write(v)
|
||||
offset += uint32(len(v))
|
||||
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
|
||||
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
|
||||
}
|
||||
|
||||
// Now we write the index data for real
|
||||
copy(b.Bytes()[indexStart:], indexData)
|
||||
}
|
||||
|
||||
func (h *HandshakeMessage) getTagsSorted() []Tag {
|
||||
tags := make([]Tag, len(h.Data))
|
||||
i := 0
|
||||
for t := range h.Data {
|
||||
tags[i] = t
|
||||
i++
|
||||
}
|
||||
sort.Slice(tags, func(i, j int) bool {
|
||||
return tags[i] < tags[j]
|
||||
})
|
||||
return tags
|
||||
}
|
||||
|
||||
func (h HandshakeMessage) String() string {
|
||||
var pad string
|
||||
res := tagToString(h.Tag) + ":\n"
|
||||
for _, tag := range h.getTagsSorted() {
|
||||
if tag == TagPAD {
|
||||
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
|
||||
} else {
|
||||
res += fmt.Sprintf("\t%s: %#v\n", tagToString(tag), string(h.Data[tag]))
|
||||
}
|
||||
}
|
||||
|
||||
if len(pad) > 0 {
|
||||
res += pad
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func tagToString(tag Tag) string {
|
||||
b := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(b, uint32(tag))
|
||||
for i := range b {
|
||||
if b[i] == 0 {
|
||||
b[i] = ' '
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
|
@ -2,54 +2,43 @@ package handshake
|
|||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"io"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
// Opener opens a packet
|
||||
type Opener interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// Sealer seals a packet
|
||||
type Sealer interface {
|
||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
Overhead() int
|
||||
}
|
||||
|
||||
// mintTLS combines some methods needed to interact with mint.
|
||||
type mintTLS interface {
|
||||
crypto.TLSExporter
|
||||
Handshake() mint.Alert
|
||||
// A tlsExtensionHandler sends and received the QUIC TLS extension.
|
||||
type tlsExtensionHandler interface {
|
||||
GetExtensions(msgType uint8) []qtls.Extension
|
||||
ReceivedExtensions(msgType uint8, exts []qtls.Extension) error
|
||||
}
|
||||
|
||||
// A TLSExtensionHandler sends and received the QUIC TLS extension.
|
||||
// It provides the parameters sent by the peer on a channel.
|
||||
type TLSExtensionHandler interface {
|
||||
Send(mint.HandshakeType, *mint.ExtensionList) error
|
||||
Receive(mint.HandshakeType, *mint.ExtensionList) error
|
||||
GetPeerParams() <-chan TransportParameters
|
||||
}
|
||||
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||
type CryptoSetup interface {
|
||||
RunHandshake() error
|
||||
io.Closer
|
||||
|
||||
type baseCryptoSetup interface {
|
||||
HandleCryptoStream() error
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||
}
|
||||
|
||||
// CryptoSetup is the crypto setup used by gQUIC
|
||||
type CryptoSetup interface {
|
||||
baseCryptoSetup
|
||||
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
}
|
||||
|
||||
// CryptoSetupTLS is the crypto setup used by IETF QUIC
|
||||
type CryptoSetupTLS interface {
|
||||
baseCryptoSetup
|
||||
|
||||
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
package handshake
|
||||
|
||||
//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS"
|
48
vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go
generated
vendored
Normal file
48
vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go
generated
vendored
Normal file
|
@ -0,0 +1,48 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
||||
if c == nil {
|
||||
c = &tls.Config{}
|
||||
}
|
||||
// QUIC requires TLS 1.3 or newer
|
||||
if c.MinVersion < qtls.VersionTLS13 {
|
||||
c.MinVersion = qtls.VersionTLS13
|
||||
}
|
||||
if c.MaxVersion < qtls.VersionTLS13 {
|
||||
c.MaxVersion = qtls.VersionTLS13
|
||||
}
|
||||
return &qtls.Config{
|
||||
Rand: c.Rand,
|
||||
Time: c.Time,
|
||||
Certificates: c.Certificates,
|
||||
NameToCertificate: c.NameToCertificate,
|
||||
// TODO: make GetCertificate work
|
||||
// GetCertificate: c.GetCertificate,
|
||||
GetClientCertificate: c.GetClientCertificate,
|
||||
// TODO: make GetConfigForClient work
|
||||
// GetConfigForClient: c.GetConfigForClient,
|
||||
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||
RootCAs: c.RootCAs,
|
||||
NextProtos: c.NextProtos,
|
||||
ServerName: c.ServerName,
|
||||
ClientAuth: c.ClientAuth,
|
||||
ClientCAs: c.ClientCAs,
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
CipherSuites: c.CipherSuites,
|
||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||
SessionTicketKey: c.SessionTicketKey,
|
||||
MinVersion: c.MinVersion,
|
||||
MaxVersion: c.MaxVersion,
|
||||
CurvePreferences: c.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||
Renegotiation: c.Renegotiation,
|
||||
KeyLogWriter: c.KeyLogWriter,
|
||||
}
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
)
|
||||
|
||||
// ServerConfig is a server config
|
||||
type ServerConfig struct {
|
||||
kex crypto.KeyExchange
|
||||
certChain crypto.CertChain
|
||||
ID []byte
|
||||
obit []byte
|
||||
cookieGenerator *CookieGenerator
|
||||
}
|
||||
|
||||
// NewServerConfig creates a new server config
|
||||
func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
|
||||
id := make([]byte, 16)
|
||||
_, err := rand.Read(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
obit := make([]byte, 8)
|
||||
if _, err = rand.Read(obit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookieGenerator, err := NewCookieGenerator()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ServerConfig{
|
||||
kex: kex,
|
||||
certChain: certChain,
|
||||
ID: id,
|
||||
obit: obit,
|
||||
cookieGenerator: cookieGenerator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get the server config binary representation
|
||||
func (s *ServerConfig) Get() []byte {
|
||||
var serverConfig bytes.Buffer
|
||||
msg := HandshakeMessage{
|
||||
Tag: TagSCFG,
|
||||
Data: map[Tag][]byte{
|
||||
TagSCID: s.ID,
|
||||
TagKEXS: []byte("C255"),
|
||||
TagAEAD: []byte("AESG"),
|
||||
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
|
||||
TagOBIT: s.obit,
|
||||
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
}
|
||||
msg.Write(&serverConfig)
|
||||
return serverConfig.Bytes()
|
||||
}
|
||||
|
||||
// Sign the server config and CHLO with the server's keyData
|
||||
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
|
||||
return s.certChain.SignServerProof(sni, chlo, s.Get())
|
||||
}
|
||||
|
||||
// GetCertsCompressed returns the certificate data
|
||||
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
|
||||
return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
|
||||
}
|
184
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go
generated
vendored
184
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go
generated
vendored
|
@ -1,184 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type serverConfigClient struct {
|
||||
raw []byte
|
||||
ID []byte
|
||||
obit []byte
|
||||
expiry time.Time
|
||||
|
||||
kex crypto.KeyExchange
|
||||
sharedSecret []byte
|
||||
}
|
||||
|
||||
var (
|
||||
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
|
||||
)
|
||||
|
||||
// parseServerConfig parses a server config
|
||||
func parseServerConfig(data []byte) (*serverConfigClient, error) {
|
||||
message, err := ParseHandshakeMessage(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if message.Tag != TagSCFG {
|
||||
return nil, errMessageNotServerConfig
|
||||
}
|
||||
|
||||
scfg := &serverConfigClient{raw: data}
|
||||
err = scfg.parseValues(message.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return scfg, nil
|
||||
}
|
||||
|
||||
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
|
||||
// SCID
|
||||
scfgID, ok := tagMap[TagSCID]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
|
||||
}
|
||||
if len(scfgID) != 16 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
|
||||
}
|
||||
s.ID = scfgID
|
||||
|
||||
// KEXS
|
||||
// TODO: setup Key Exchange
|
||||
kexs, ok := tagMap[TagKEXS]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
|
||||
}
|
||||
if len(kexs)%4 != 0 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
|
||||
}
|
||||
c255Foundat := -1
|
||||
|
||||
for i := 0; i < len(kexs)/4; i++ {
|
||||
if bytes.Equal(kexs[4*i:4*i+4], []byte("C255")) {
|
||||
c255Foundat = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if c255Foundat < 0 {
|
||||
return qerr.Error(qerr.CryptoNoSupport, "KEXS: Could not find C255, other key exchanges are not supported")
|
||||
}
|
||||
|
||||
// AEAD
|
||||
aead, ok := tagMap[TagAEAD]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
|
||||
}
|
||||
if len(aead)%4 != 0 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
|
||||
}
|
||||
var aesgFound bool
|
||||
for i := 0; i < len(aead)/4; i++ {
|
||||
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
|
||||
aesgFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !aesgFound {
|
||||
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
|
||||
}
|
||||
|
||||
// PUBS
|
||||
pubs, ok := tagMap[TagPUBS]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||
}
|
||||
|
||||
var pubsKexs []struct {
|
||||
Length uint32
|
||||
Value []byte
|
||||
}
|
||||
var lastLen uint32
|
||||
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
|
||||
// the PUBS value is always prepended by 3 byte little endian length field
|
||||
|
||||
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
|
||||
}
|
||||
if lastLen == 0 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
if i+3+int(lastLen) > len(pubs) {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
pubsKexs = append(pubsKexs, struct {
|
||||
Length uint32
|
||||
Value []byte
|
||||
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
|
||||
}
|
||||
|
||||
if c255Foundat >= len(pubsKexs) {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
|
||||
}
|
||||
|
||||
if pubsKexs[c255Foundat].Length != 32 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
var err error
|
||||
s.kex, err = crypto.NewCurve25519KEX()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// OBIT
|
||||
obit, ok := tagMap[TagOBIT]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
|
||||
}
|
||||
if len(obit) != 8 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
|
||||
}
|
||||
s.obit = obit
|
||||
|
||||
// EXPY
|
||||
expy, ok := tagMap[TagEXPY]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
|
||||
}
|
||||
if len(expy) != 8 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
|
||||
}
|
||||
// make sure that the value doesn't overflow an int64
|
||||
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
|
||||
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
|
||||
s.expiry = time.Unix(int64(expyTimestamp), 0)
|
||||
|
||||
// TODO: implement VER
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serverConfigClient) IsExpired() bool {
|
||||
return s.expiry.Before(time.Now())
|
||||
}
|
||||
|
||||
func (s *serverConfigClient) Get() []byte {
|
||||
return s.raw
|
||||
}
|
|
@ -1,93 +0,0 @@
|
|||
package handshake
|
||||
|
||||
// A Tag in the QUIC crypto
|
||||
type Tag uint32
|
||||
|
||||
const (
|
||||
// TagCHLO is a client hello
|
||||
TagCHLO Tag = 'C' + 'H'<<8 + 'L'<<16 + 'O'<<24
|
||||
// TagREJ is a server hello rejection
|
||||
TagREJ Tag = 'R' + 'E'<<8 + 'J'<<16
|
||||
// TagSCFG is a server config
|
||||
TagSCFG Tag = 'S' + 'C'<<8 + 'F'<<16 + 'G'<<24
|
||||
|
||||
// TagPAD is padding
|
||||
TagPAD Tag = 'P' + 'A'<<8 + 'D'<<16
|
||||
// TagSNI is the server name indication
|
||||
TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16
|
||||
// TagVER is the QUIC version
|
||||
TagVER Tag = 'V' + 'E'<<8 + 'R'<<16
|
||||
// TagCCS are the hashes of the common certificate sets
|
||||
TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16
|
||||
// TagCCRT are the hashes of the cached certificates
|
||||
TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24
|
||||
// TagMSPC is max streams per connection
|
||||
TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24
|
||||
// TagMIDS is max incoming dyanamic streams
|
||||
TagMIDS Tag = 'M' + 'I'<<8 + 'D'<<16 + 'S'<<24
|
||||
// TagUAID is the user agent ID
|
||||
TagUAID Tag = 'U' + 'A'<<8 + 'I'<<16 + 'D'<<24
|
||||
// TagSVID is the server ID (unofficial tag by us :)
|
||||
TagSVID Tag = 'S' + 'V'<<8 + 'I'<<16 + 'D'<<24
|
||||
// TagTCID is truncation of the connection ID
|
||||
TagTCID Tag = 'T' + 'C'<<8 + 'I'<<16 + 'D'<<24
|
||||
// TagPDMD is the proof demand
|
||||
TagPDMD Tag = 'P' + 'D'<<8 + 'M'<<16 + 'D'<<24
|
||||
// TagSRBF is the socket receive buffer
|
||||
TagSRBF Tag = 'S' + 'R'<<8 + 'B'<<16 + 'F'<<24
|
||||
// TagICSL is the idle connection state lifetime
|
||||
TagICSL Tag = 'I' + 'C'<<8 + 'S'<<16 + 'L'<<24
|
||||
// TagNONP is the client proof nonce
|
||||
TagNONP Tag = 'N' + 'O'<<8 + 'N'<<16 + 'P'<<24
|
||||
// TagSCLS is the silently close timeout
|
||||
TagSCLS Tag = 'S' + 'C'<<8 + 'L'<<16 + 'S'<<24
|
||||
// TagCSCT is the signed cert timestamp (RFC6962) of leaf cert
|
||||
TagCSCT Tag = 'C' + 'S'<<8 + 'C'<<16 + 'T'<<24
|
||||
// TagCOPT are the connection options
|
||||
TagCOPT Tag = 'C' + 'O'<<8 + 'P'<<16 + 'T'<<24
|
||||
// TagCFCW is the initial session/connection flow control receive window
|
||||
TagCFCW Tag = 'C' + 'F'<<8 + 'C'<<16 + 'W'<<24
|
||||
// TagSFCW is the initial stream flow control receive window.
|
||||
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
|
||||
|
||||
// TagNSTP is the no STOP_WAITING experiment
|
||||
// currently unsupported by quic-go
|
||||
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
|
||||
|
||||
// TagSTK is the source-address token
|
||||
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
|
||||
// TagSNO is the server nonce
|
||||
TagSNO Tag = 'S' + 'N'<<8 + 'O'<<16
|
||||
// TagPROF is the server proof
|
||||
TagPROF Tag = 'P' + 'R'<<8 + 'O'<<16 + 'F'<<24
|
||||
|
||||
// TagNONC is the client nonce
|
||||
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
|
||||
// TagXLCT is the expected leaf certificate
|
||||
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
|
||||
|
||||
// TagSCID is the server config ID
|
||||
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
|
||||
// TagKEXS is the list of key exchange algos
|
||||
TagKEXS Tag = 'K' + 'E'<<8 + 'X'<<16 + 'S'<<24
|
||||
// TagAEAD is the list of AEAD algos
|
||||
TagAEAD Tag = 'A' + 'E'<<8 + 'A'<<16 + 'D'<<24
|
||||
// TagPUBS is the public value for the KEX
|
||||
TagPUBS Tag = 'P' + 'U'<<8 + 'B'<<16 + 'S'<<24
|
||||
// TagOBIT is the client orbit
|
||||
TagOBIT Tag = 'O' + 'B'<<8 + 'I'<<16 + 'T'<<24
|
||||
// TagEXPY is the server config expiry
|
||||
TagEXPY Tag = 'E' + 'X'<<8 + 'P'<<16 + 'Y'<<24
|
||||
// TagCERT is the CERT data
|
||||
TagCERT Tag = 0xff545243
|
||||
|
||||
// TagSHLO is the server hello
|
||||
TagSHLO Tag = 'S' + 'H'<<8 + 'L'<<16 + 'O'<<24
|
||||
|
||||
// TagPRST is the public reset tag
|
||||
TagPRST Tag = 'P' + 'R'<<8 + 'S'<<16 + 'T'<<24
|
||||
// TagRSEQ is the public reset rejected packet number
|
||||
TagRSEQ Tag = 'R' + 'S'<<8 + 'E'<<16 + 'Q'<<24
|
||||
// TagRNON is the public reset nonce
|
||||
TagRNON Tag = 'R' + 'N'<<8 + 'O'<<16 + 'N'<<24
|
||||
)
|
|
@ -6,26 +6,12 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type transportParameterID uint16
|
||||
|
||||
const quicTLSExtensionType = 0xff5
|
||||
|
||||
const (
|
||||
initialMaxStreamDataParameterID transportParameterID = 0x0
|
||||
initialMaxDataParameterID transportParameterID = 0x1
|
||||
initialMaxBidiStreamsParameterID transportParameterID = 0x2
|
||||
idleTimeoutParameterID transportParameterID = 0x3
|
||||
maxPacketSizeParameterID transportParameterID = 0x5
|
||||
statelessResetTokenParameterID transportParameterID = 0x6
|
||||
initialMaxUniStreamsParameterID transportParameterID = 0x8
|
||||
disableMigrationParameterID transportParameterID = 0x9
|
||||
)
|
||||
|
||||
type clientHelloTransportParameters struct {
|
||||
InitialVersion protocol.VersionNumber
|
||||
Parameters TransportParameters
|
||||
|
@ -52,7 +38,7 @@ func (p *clientHelloTransportParameters) Unmarshal(data []byte) error {
|
|||
if len(data) != paramsLen {
|
||||
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
|
||||
}
|
||||
return p.Parameters.unmarshal(data)
|
||||
return p.Parameters.unmarshal(data, protocol.PerspectiveClient)
|
||||
}
|
||||
|
||||
type encryptedExtensionsTransportParameters struct {
|
||||
|
@ -100,24 +86,5 @@ func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error {
|
|||
if len(data) != paramsLen {
|
||||
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
|
||||
}
|
||||
return p.Parameters.unmarshal(data)
|
||||
}
|
||||
|
||||
type tlsExtensionBody struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
var _ mint.ExtensionBody = &tlsExtensionBody{}
|
||||
|
||||
func (e *tlsExtensionBody) Type() mint.ExtensionType {
|
||||
return quicTLSExtensionType
|
||||
}
|
||||
|
||||
func (e *tlsExtensionBody) Marshal() ([]byte, error) {
|
||||
return e.data, nil
|
||||
}
|
||||
|
||||
func (e *tlsExtensionBody) Unmarshal(data []byte) (int, error) {
|
||||
e.data = data
|
||||
return len(data), nil
|
||||
return p.Parameters.unmarshal(data, protocol.PerspectiveServer)
|
||||
}
|
||||
|
|
|
@ -4,17 +4,17 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type extensionHandlerClient struct {
|
||||
ourParams *TransportParameters
|
||||
paramsChan chan TransportParameters
|
||||
paramsChan chan<- TransportParameters
|
||||
|
||||
origConnID protocol.ConnectionID
|
||||
initialVersion protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
@ -22,17 +22,17 @@ type extensionHandlerClient struct {
|
|||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
|
||||
var _ TLSExtensionHandler = &extensionHandlerClient{}
|
||||
var _ tlsExtensionHandler = &extensionHandlerClient{}
|
||||
|
||||
// NewExtensionHandlerClient creates a new extension handler for the client.
|
||||
func NewExtensionHandlerClient(
|
||||
// newExtensionHandlerClient creates a new extension handler for the client.
|
||||
func newExtensionHandlerClient(
|
||||
params *TransportParameters,
|
||||
origConnID protocol.ConnectionID,
|
||||
initialVersion protocol.VersionNumber,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
version protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
) TLSExtensionHandler {
|
||||
) (tlsExtensionHandler, <-chan TransportParameters) {
|
||||
// The client reads the transport parameters from the Encrypted Extensions message.
|
||||
// The paramsChan is used in the session's run loop's select statement.
|
||||
// We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately.
|
||||
|
@ -40,48 +40,48 @@ func NewExtensionHandlerClient(
|
|||
return &extensionHandlerClient{
|
||||
ourParams: params,
|
||||
paramsChan: paramsChan,
|
||||
origConnID: origConnID,
|
||||
initialVersion: initialVersion,
|
||||
supportedVersions: supportedVersions,
|
||||
version: version,
|
||||
logger: logger,
|
||||
}
|
||||
}, paramsChan
|
||||
}
|
||||
|
||||
func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
||||
if hType != mint.HandshakeTypeClientHello {
|
||||
func (h *extensionHandlerClient) GetExtensions(msgType uint8) []qtls.Extension {
|
||||
if messageType(msgType) != typeClientHello {
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
|
||||
chtp := &clientHelloTransportParameters{
|
||||
InitialVersion: h.initialVersion,
|
||||
Parameters: *h.ourParams,
|
||||
}
|
||||
return el.Add(&tlsExtensionBody{data: chtp.Marshal()})
|
||||
return []qtls.Extension{{
|
||||
Type: quicTLSExtensionType,
|
||||
Data: (&clientHelloTransportParameters{
|
||||
InitialVersion: h.initialVersion,
|
||||
Parameters: *h.ourParams,
|
||||
}).Marshal(),
|
||||
}}
|
||||
}
|
||||
|
||||
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
||||
ext := &tlsExtensionBody{}
|
||||
found, err := el.Find(ext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hType != mint.HandshakeTypeEncryptedExtensions {
|
||||
if found {
|
||||
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
|
||||
}
|
||||
func (h *extensionHandlerClient) ReceivedExtensions(msgType uint8, exts []qtls.Extension) error {
|
||||
if messageType(msgType) != typeEncryptedExtensions {
|
||||
return nil
|
||||
}
|
||||
|
||||
// hType == mint.HandshakeTypeEncryptedExtensions
|
||||
var found bool
|
||||
eetp := &encryptedExtensionsTransportParameters{}
|
||||
for _, ext := range exts {
|
||||
if ext.Type != quicTLSExtensionType {
|
||||
continue
|
||||
}
|
||||
if err := eetp.Unmarshal(ext.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
found = true
|
||||
}
|
||||
if !found {
|
||||
return errors.New("EncryptedExtensions message didn't contain a QUIC extension")
|
||||
}
|
||||
|
||||
eetp := &encryptedExtensionsTransportParameters{}
|
||||
if err := eetp.Unmarshal(ext.data); err != nil {
|
||||
return err
|
||||
}
|
||||
// check that the negotiated_version is the current version
|
||||
if eetp.NegotiatedVersion != h.version {
|
||||
return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
|
||||
|
@ -98,15 +98,16 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
|
|||
}
|
||||
}
|
||||
|
||||
params := eetp.Parameters
|
||||
// check that the server sent a stateless reset token
|
||||
if len(eetp.Parameters.StatelessResetToken) == 0 {
|
||||
if len(params.StatelessResetToken) == 0 {
|
||||
return errors.New("server didn't sent stateless_reset_token")
|
||||
}
|
||||
h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters)
|
||||
h.paramsChan <- eetp.Parameters
|
||||
// check the Retry token
|
||||
if !h.origConnID.Equal(params.OriginalConnectionID) {
|
||||
return fmt.Errorf("expected original_connection_id to equal %s, is %s", h.origConnID, params.OriginalConnectionID)
|
||||
}
|
||||
h.logger.Debugf("Received Transport Parameters: %s", ¶ms)
|
||||
h.paramsChan <- params
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
|
||||
return h.paramsChan
|
||||
}
|
||||
|
|
|
@ -2,18 +2,16 @@ package handshake
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
type extensionHandlerServer struct {
|
||||
ourParams *TransportParameters
|
||||
paramsChan chan TransportParameters
|
||||
paramsChan chan<- TransportParameters
|
||||
|
||||
version protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
|
@ -21,62 +19,60 @@ type extensionHandlerServer struct {
|
|||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
|
||||
var _ TLSExtensionHandler = &extensionHandlerServer{}
|
||||
var _ tlsExtensionHandler = &extensionHandlerServer{}
|
||||
|
||||
// NewExtensionHandlerServer creates a new extension handler for the server
|
||||
func NewExtensionHandlerServer(
|
||||
// newExtensionHandlerServer creates a new extension handler for the server
|
||||
func newExtensionHandlerServer(
|
||||
params *TransportParameters,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
version protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
) TLSExtensionHandler {
|
||||
) (tlsExtensionHandler, <-chan TransportParameters) {
|
||||
// Processing the ClientHello is performed statelessly (and from a single go-routine).
|
||||
// Therefore, we have to use a buffered chan to pass the transport parameters to that go routine.
|
||||
paramsChan := make(chan TransportParameters, 1)
|
||||
paramsChan := make(chan TransportParameters)
|
||||
return &extensionHandlerServer{
|
||||
ourParams: params,
|
||||
paramsChan: paramsChan,
|
||||
supportedVersions: supportedVersions,
|
||||
version: version,
|
||||
logger: logger,
|
||||
}
|
||||
}, paramsChan
|
||||
}
|
||||
|
||||
func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
||||
if hType != mint.HandshakeTypeEncryptedExtensions {
|
||||
func (h *extensionHandlerServer) GetExtensions(msgType uint8) []qtls.Extension {
|
||||
if messageType(msgType) != typeEncryptedExtensions {
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
|
||||
eetp := &encryptedExtensionsTransportParameters{
|
||||
NegotiatedVersion: h.version,
|
||||
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
|
||||
Parameters: *h.ourParams,
|
||||
}
|
||||
return el.Add(&tlsExtensionBody{data: eetp.Marshal()})
|
||||
return []qtls.Extension{{
|
||||
Type: quicTLSExtensionType,
|
||||
Data: (&encryptedExtensionsTransportParameters{
|
||||
NegotiatedVersion: h.version,
|
||||
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
|
||||
Parameters: *h.ourParams,
|
||||
}).Marshal(),
|
||||
}}
|
||||
}
|
||||
|
||||
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
||||
ext := &tlsExtensionBody{}
|
||||
found, err := el.Find(ext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hType != mint.HandshakeTypeClientHello {
|
||||
if found {
|
||||
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
|
||||
}
|
||||
func (h *extensionHandlerServer) ReceivedExtensions(msgType uint8, exts []qtls.Extension) error {
|
||||
if messageType(msgType) != typeClientHello {
|
||||
return nil
|
||||
}
|
||||
|
||||
var found bool
|
||||
chtp := &clientHelloTransportParameters{}
|
||||
for _, ext := range exts {
|
||||
if ext.Type != quicTLSExtensionType {
|
||||
continue
|
||||
}
|
||||
if err := chtp.Unmarshal(ext.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
found = true
|
||||
}
|
||||
if !found {
|
||||
return errors.New("ClientHello didn't contain a QUIC extension")
|
||||
}
|
||||
chtp := &clientHelloTransportParameters{}
|
||||
if err := chtp.Unmarshal(ext.data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// perform the stateless version negotiation validation:
|
||||
// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
|
||||
|
@ -84,17 +80,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
|
|||
if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
|
||||
return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
|
||||
}
|
||||
|
||||
// check that the client didn't send a stateless reset token
|
||||
if len(chtp.Parameters.StatelessResetToken) != 0 {
|
||||
// TODO: return the correct error type
|
||||
return errors.New("client sent a stateless reset token")
|
||||
}
|
||||
h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
|
||||
h.paramsChan <- chtp.Parameters
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
|
||||
return h.paramsChan
|
||||
}
|
||||
|
|
301
vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go
generated
vendored
301
vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go
generated
vendored
|
@ -2,194 +2,190 @@ package handshake
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// errMalformedTag is returned when the tag value cannot be read
|
||||
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
||||
type transportParameterID uint16
|
||||
|
||||
const (
|
||||
originalConnectionIDParameterID transportParameterID = 0x0
|
||||
idleTimeoutParameterID transportParameterID = 0x1
|
||||
statelessResetTokenParameterID transportParameterID = 0x2
|
||||
maxPacketSizeParameterID transportParameterID = 0x3
|
||||
initialMaxDataParameterID transportParameterID = 0x4
|
||||
initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5
|
||||
initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6
|
||||
initialMaxStreamDataUniParameterID transportParameterID = 0x7
|
||||
initialMaxStreamsBidiParameterID transportParameterID = 0x8
|
||||
initialMaxStreamsUniParameterID transportParameterID = 0x9
|
||||
disableMigrationParameterID transportParameterID = 0xc
|
||||
)
|
||||
|
||||
// TransportParameters are parameters sent to the peer during the handshake
|
||||
type TransportParameters struct {
|
||||
StreamFlowControlWindow protocol.ByteCount
|
||||
ConnectionFlowControlWindow protocol.ByteCount
|
||||
InitialMaxStreamDataBidiLocal protocol.ByteCount
|
||||
InitialMaxStreamDataBidiRemote protocol.ByteCount
|
||||
InitialMaxStreamDataUni protocol.ByteCount
|
||||
InitialMaxData protocol.ByteCount
|
||||
|
||||
MaxPacketSize protocol.ByteCount
|
||||
|
||||
MaxUniStreams uint16 // only used for IETF QUIC
|
||||
MaxBidiStreams uint16 // only used for IETF QUIC
|
||||
MaxStreams uint32 // only used for gQUIC
|
||||
MaxUniStreams uint64
|
||||
MaxBidiStreams uint64
|
||||
|
||||
OmitConnectionID bool // only used for gQUIC
|
||||
IdleTimeout time.Duration
|
||||
DisableMigration bool // only used for IETF QUIC
|
||||
StatelessResetToken []byte // only used for IETF QUIC
|
||||
IdleTimeout time.Duration
|
||||
DisableMigration bool
|
||||
|
||||
StatelessResetToken []byte
|
||||
OriginalConnectionID protocol.ConnectionID
|
||||
}
|
||||
|
||||
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
|
||||
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
|
||||
params := &TransportParameters{}
|
||||
if value, ok := tags[TagTCID]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective) error {
|
||||
// needed to check that every parameter is only sent at most once
|
||||
var parameterIDs []transportParameterID
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
for r.Len() >= 4 {
|
||||
paramIDInt, _ := utils.BigEndian.ReadUint16(r)
|
||||
paramID := transportParameterID(paramIDInt)
|
||||
paramLen, _ := utils.BigEndian.ReadUint16(r)
|
||||
parameterIDs = append(parameterIDs, paramID)
|
||||
switch paramID {
|
||||
case initialMaxStreamDataBidiLocalParameterID,
|
||||
initialMaxStreamDataBidiRemoteParameterID,
|
||||
initialMaxStreamDataUniParameterID,
|
||||
initialMaxDataParameterID,
|
||||
initialMaxStreamsBidiParameterID,
|
||||
initialMaxStreamsUniParameterID,
|
||||
idleTimeoutParameterID,
|
||||
maxPacketSizeParameterID:
|
||||
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if r.Len() < int(paramLen) {
|
||||
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
|
||||
}
|
||||
switch paramID {
|
||||
case disableMigrationParameterID:
|
||||
if paramLen != 0 {
|
||||
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
|
||||
}
|
||||
p.DisableMigration = true
|
||||
case statelessResetTokenParameterID:
|
||||
if sentBy == protocol.PerspectiveClient {
|
||||
return errors.New("client sent a stateless_reset_token")
|
||||
}
|
||||
if paramLen != 16 {
|
||||
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
|
||||
}
|
||||
b := make([]byte, 16)
|
||||
r.Read(b)
|
||||
p.StatelessResetToken = b
|
||||
case originalConnectionIDParameterID:
|
||||
if sentBy == protocol.PerspectiveClient {
|
||||
return errors.New("client sent an original_connection_id")
|
||||
}
|
||||
p.OriginalConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
|
||||
default:
|
||||
r.Seek(int64(paramLen), io.SeekCurrent)
|
||||
}
|
||||
}
|
||||
params.OmitConnectionID = (v == 0)
|
||||
}
|
||||
if value, ok := tags[TagMIDS]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
|
||||
// check that every transport parameter was sent at most once
|
||||
sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] })
|
||||
for i := 0; i < len(parameterIDs)-1; i++ {
|
||||
if parameterIDs[i] == parameterIDs[i+1] {
|
||||
return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i])
|
||||
}
|
||||
params.MaxStreams = v
|
||||
}
|
||||
if value, ok := tags[TagICSL]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
}
|
||||
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
|
||||
|
||||
if r.Len() != 0 {
|
||||
return fmt.Errorf("should have read all data. Still have %d bytes", r.Len())
|
||||
}
|
||||
if value, ok := tags[TagSFCW]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
}
|
||||
params.StreamFlowControlWindow = protocol.ByteCount(v)
|
||||
}
|
||||
if value, ok := tags[TagCFCW]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
}
|
||||
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
|
||||
}
|
||||
return params, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
|
||||
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
|
||||
sfcw := bytes.NewBuffer([]byte{})
|
||||
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
|
||||
cfcw := bytes.NewBuffer([]byte{})
|
||||
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
|
||||
mids := bytes.NewBuffer([]byte{})
|
||||
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
|
||||
icsl := bytes.NewBuffer([]byte{})
|
||||
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
|
||||
|
||||
tags := map[Tag][]byte{
|
||||
TagICSL: icsl.Bytes(),
|
||||
TagMIDS: mids.Bytes(),
|
||||
TagCFCW: cfcw.Bytes(),
|
||||
TagSFCW: sfcw.Bytes(),
|
||||
func (p *TransportParameters) readNumericTransportParameter(
|
||||
r *bytes.Reader,
|
||||
paramID transportParameterID,
|
||||
expectedLen int,
|
||||
) error {
|
||||
remainingLen := r.Len()
|
||||
val, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
|
||||
}
|
||||
if p.OmitConnectionID {
|
||||
tags[TagTCID] = []byte{0, 0, 0, 0}
|
||||
if remainingLen-r.Len() != expectedLen {
|
||||
return fmt.Errorf("inconsistent transport parameter length for %d", paramID)
|
||||
}
|
||||
return tags
|
||||
}
|
||||
|
||||
func (p *TransportParameters) unmarshal(data []byte) error {
|
||||
var foundIdleTimeout bool
|
||||
|
||||
for len(data) >= 4 {
|
||||
paramID := binary.BigEndian.Uint16(data[:2])
|
||||
paramLen := int(binary.BigEndian.Uint16(data[2:4]))
|
||||
data = data[4:]
|
||||
if len(data) < paramLen {
|
||||
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen)
|
||||
switch paramID {
|
||||
case initialMaxStreamDataBidiLocalParameterID:
|
||||
p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val)
|
||||
case initialMaxStreamDataBidiRemoteParameterID:
|
||||
p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val)
|
||||
case initialMaxStreamDataUniParameterID:
|
||||
p.InitialMaxStreamDataUni = protocol.ByteCount(val)
|
||||
case initialMaxDataParameterID:
|
||||
p.InitialMaxData = protocol.ByteCount(val)
|
||||
case initialMaxStreamsBidiParameterID:
|
||||
p.MaxBidiStreams = val
|
||||
case initialMaxStreamsUniParameterID:
|
||||
p.MaxUniStreams = val
|
||||
case idleTimeoutParameterID:
|
||||
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Second)
|
||||
case maxPacketSizeParameterID:
|
||||
if val < 1200 {
|
||||
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
|
||||
}
|
||||
switch transportParameterID(paramID) {
|
||||
case initialMaxStreamDataParameterID:
|
||||
if paramLen != 4 {
|
||||
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen)
|
||||
}
|
||||
p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
|
||||
case initialMaxDataParameterID:
|
||||
if paramLen != 4 {
|
||||
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen)
|
||||
}
|
||||
p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
|
||||
case initialMaxBidiStreamsParameterID:
|
||||
if paramLen != 2 {
|
||||
return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen)
|
||||
}
|
||||
p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2])
|
||||
case initialMaxUniStreamsParameterID:
|
||||
if paramLen != 2 {
|
||||
return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen)
|
||||
}
|
||||
p.MaxUniStreams = binary.BigEndian.Uint16(data[:2])
|
||||
case idleTimeoutParameterID:
|
||||
foundIdleTimeout = true
|
||||
if paramLen != 2 {
|
||||
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen)
|
||||
}
|
||||
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second)
|
||||
case maxPacketSizeParameterID:
|
||||
if paramLen != 2 {
|
||||
return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen)
|
||||
}
|
||||
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2]))
|
||||
if maxPacketSize < 1200 {
|
||||
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
|
||||
}
|
||||
p.MaxPacketSize = maxPacketSize
|
||||
case disableMigrationParameterID:
|
||||
if paramLen != 0 {
|
||||
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
|
||||
}
|
||||
p.DisableMigration = true
|
||||
case statelessResetTokenParameterID:
|
||||
if paramLen != 16 {
|
||||
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
|
||||
}
|
||||
p.StatelessResetToken = data[:16]
|
||||
}
|
||||
data = data[paramLen:]
|
||||
}
|
||||
|
||||
if len(data) != 0 {
|
||||
return fmt.Errorf("should have read all data. Still have %d bytes", len(data))
|
||||
}
|
||||
if !foundIdleTimeout {
|
||||
return errors.New("missing parameter")
|
||||
p.MaxPacketSize = protocol.ByteCount(val)
|
||||
default:
|
||||
return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *TransportParameters) marshal(b *bytes.Buffer) {
|
||||
// initial_max_stream_data
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 4)
|
||||
utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow))
|
||||
// initial_max_stream_data_bidi_local
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiLocalParameterID))
|
||||
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiLocal))))
|
||||
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiLocal))
|
||||
// initial_max_stream_data_bidi_remote
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiRemoteParameterID))
|
||||
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiRemote))))
|
||||
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiRemote))
|
||||
// initial_max_stream_data_uni
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataUniParameterID))
|
||||
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataUni))))
|
||||
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataUni))
|
||||
// initial_max_data
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 4)
|
||||
utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow))
|
||||
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxData))))
|
||||
utils.WriteVarInt(b, uint64(p.InitialMaxData))
|
||||
// initial_max_bidi_streams
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 2)
|
||||
utils.BigEndian.WriteUint16(b, p.MaxBidiStreams)
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsBidiParameterID))
|
||||
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxBidiStreams)))
|
||||
utils.WriteVarInt(b, p.MaxBidiStreams)
|
||||
// initial_max_uni_streams
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 2)
|
||||
utils.BigEndian.WriteUint16(b, p.MaxUniStreams)
|
||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsUniParameterID))
|
||||
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxUniStreams)))
|
||||
utils.WriteVarInt(b, p.MaxUniStreams)
|
||||
// idle_timeout
|
||||
utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 2)
|
||||
utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second))
|
||||
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.IdleTimeout/time.Second))))
|
||||
utils.WriteVarInt(b, uint64(p.IdleTimeout/time.Second))
|
||||
// max_packet_size
|
||||
utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
|
||||
utils.BigEndian.WriteUint16(b, 2)
|
||||
utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize))
|
||||
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(protocol.MaxReceivePacketSize))))
|
||||
utils.WriteVarInt(b, uint64(protocol.MaxReceivePacketSize))
|
||||
// disable_migration
|
||||
if p.DisableMigration {
|
||||
utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
|
||||
|
@ -200,10 +196,15 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) {
|
|||
utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
|
||||
b.Write(p.StatelessResetToken)
|
||||
}
|
||||
// original_connection_id
|
||||
if p.OriginalConnectionID.Len() > 0 {
|
||||
utils.BigEndian.WriteUint16(b, uint16(originalConnectionIDParameterID))
|
||||
utils.BigEndian.WriteUint16(b, uint16(p.OriginalConnectionID.Len()))
|
||||
b.Write(p.OriginalConnectionID.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation, intended for logging.
|
||||
// It should only used for IETF QUIC.
|
||||
func (p *TransportParameters) String() string {
|
||||
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
|
||||
return fmt.Sprintf("&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
|
||||
}
|
||||
|
|
|
@ -86,30 +86,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked()
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked))
|
||||
}
|
||||
|
||||
// GetPacketNumberLen mocks base method
|
||||
func (m *MockSentPacketHandler) GetPacketNumberLen(arg0 protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
ret := m.ctrl.Call(m, "GetPacketNumberLen", arg0)
|
||||
ret0, _ := ret[0].(protocol.PacketNumberLen)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetPacketNumberLen indicates an expected call of GetPacketNumberLen
|
||||
func (mr *MockSentPacketHandlerMockRecorder) GetPacketNumberLen(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPacketNumberLen", reflect.TypeOf((*MockSentPacketHandler)(nil).GetPacketNumberLen), arg0)
|
||||
}
|
||||
|
||||
// GetStopWaitingFrame mocks base method
|
||||
func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame {
|
||||
ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0)
|
||||
ret0, _ := ret[0].(*wire.StopWaitingFrame)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame
|
||||
func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockSentPacketHandler)(nil).GetStopWaitingFrame), arg0)
|
||||
}
|
||||
|
||||
// OnAlarm mocks base method
|
||||
func (m *MockSentPacketHandler) OnAlarm() error {
|
||||
ret := m.ctrl.Call(m, "OnAlarm")
|
||||
|
@ -122,6 +98,31 @@ func (mr *MockSentPacketHandlerMockRecorder) OnAlarm() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAlarm", reflect.TypeOf((*MockSentPacketHandler)(nil).OnAlarm))
|
||||
}
|
||||
|
||||
// PeekPacketNumber mocks base method
|
||||
func (m *MockSentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||
ret := m.ctrl.Call(m, "PeekPacketNumber")
|
||||
ret0, _ := ret[0].(protocol.PacketNumber)
|
||||
ret1, _ := ret[1].(protocol.PacketNumberLen)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PeekPacketNumber indicates an expected call of PeekPacketNumber
|
||||
func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber))
|
||||
}
|
||||
|
||||
// PopPacketNumber mocks base method
|
||||
func (m *MockSentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||
ret := m.ctrl.Call(m, "PopPacketNumber")
|
||||
ret0, _ := ret[0].(protocol.PacketNumber)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PopPacketNumber indicates an expected call of PopPacketNumber
|
||||
func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber))
|
||||
}
|
||||
|
||||
// ReceivedAck mocks base method
|
||||
func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.PacketNumber, arg2 protocol.EncryptionLevel, arg3 time.Time) error {
|
||||
ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2, arg3)
|
||||
|
|
149
vendor/github.com/lucas-clemente/quic-go/internal/mocks/crypto_setup.go
generated
vendored
Normal file
149
vendor/github.com/lucas-clemente/quic-go/internal/mocks/crypto_setup.go
generated
vendored
Normal file
|
@ -0,0 +1,149 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: CryptoSetup)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockCryptoSetup is a mock of CryptoSetup interface
|
||||
type MockCryptoSetup struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCryptoSetupMockRecorder
|
||||
}
|
||||
|
||||
// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup
|
||||
type MockCryptoSetupMockRecorder struct {
|
||||
mock *MockCryptoSetup
|
||||
}
|
||||
|
||||
// NewMockCryptoSetup creates a new mock instance
|
||||
func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup {
|
||||
mock := &MockCryptoSetup{ctrl: ctrl}
|
||||
mock.recorder = &MockCryptoSetupMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockCryptoSetup) Close() error {
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close))
|
||||
}
|
||||
|
||||
// ConnectionState mocks base method
|
||||
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState {
|
||||
ret := m.ctrl.Call(m, "ConnectionState")
|
||||
ret0, _ := ret[0].(handshake.ConnectionState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ConnectionState indicates an expected call of ConnectionState
|
||||
func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
|
||||
}
|
||||
|
||||
// GetSealer mocks base method
|
||||
func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
|
||||
ret := m.ctrl.Call(m, "GetSealer")
|
||||
ret0, _ := ret[0].(protocol.EncryptionLevel)
|
||||
ret1, _ := ret[1].(handshake.Sealer)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSealer indicates an expected call of GetSealer
|
||||
func (mr *MockCryptoSetupMockRecorder) GetSealer() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealer))
|
||||
}
|
||||
|
||||
// GetSealerWithEncryptionLevel mocks base method
|
||||
func (m *MockCryptoSetup) GetSealerWithEncryptionLevel(arg0 protocol.EncryptionLevel) (handshake.Sealer, error) {
|
||||
ret := m.ctrl.Call(m, "GetSealerWithEncryptionLevel", arg0)
|
||||
ret0, _ := ret[0].(handshake.Sealer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSealerWithEncryptionLevel indicates an expected call of GetSealerWithEncryptionLevel
|
||||
func (mr *MockCryptoSetupMockRecorder) GetSealerWithEncryptionLevel(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerWithEncryptionLevel", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealerWithEncryptionLevel), arg0)
|
||||
}
|
||||
|
||||
// HandleMessage mocks base method
|
||||
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
|
||||
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// HandleMessage indicates an expected call of HandleMessage
|
||||
func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
|
||||
}
|
||||
|
||||
// Open1RTT mocks base method
|
||||
func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||
ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Open1RTT indicates an expected call of Open1RTT
|
||||
func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// OpenHandshake mocks base method
|
||||
func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||
ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenHandshake indicates an expected call of OpenHandshake
|
||||
func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// OpenInitial mocks base method
|
||||
func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||
ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenInitial indicates an expected call of OpenInitial
|
||||
func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// RunHandshake mocks base method
|
||||
func (m *MockCryptoSetup) RunHandshake() error {
|
||||
ret := m.ctrl.Call(m, "RunHandshake")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RunHandshake indicates an expected call of RunHandshake
|
||||
func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package mocks
|
||||
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
|
||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Sealer)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockSealer is a mock of Sealer interface
|
||||
type MockSealer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSealerMockRecorder
|
||||
}
|
||||
|
||||
// MockSealerMockRecorder is the mock recorder for MockSealer
|
||||
type MockSealerMockRecorder struct {
|
||||
mock *MockSealer
|
||||
}
|
||||
|
||||
// NewMockSealer creates a new mock instance
|
||||
func NewMockSealer(ctrl *gomock.Controller) *MockSealer {
|
||||
mock := &MockSealer{ctrl: ctrl}
|
||||
mock.recorder = &MockSealerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockSealer) EXPECT() *MockSealerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Overhead mocks base method
|
||||
func (m *MockSealer) Overhead() int {
|
||||
ret := m.ctrl.Call(m, "Overhead")
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Overhead indicates an expected call of Overhead
|
||||
func (mr *MockSealerMockRecorder) Overhead() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockSealer)(nil).Overhead))
|
||||
}
|
||||
|
||||
// Seal mocks base method
|
||||
func (m *MockSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte {
|
||||
ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Seal indicates an expected call of Seal
|
||||
func (mr *MockSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockSealer)(nil).Seal), arg0, arg1, arg2, arg3)
|
||||
}
|
72
vendor/github.com/lucas-clemente/quic-go/internal/mocks/tls_extension_handler.go
generated
vendored
72
vendor/github.com/lucas-clemente/quic-go/internal/mocks/tls_extension_handler.go
generated
vendored
|
@ -1,72 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
mint "github.com/bifurcation/mint"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
)
|
||||
|
||||
// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface
|
||||
type MockTLSExtensionHandler struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTLSExtensionHandlerMockRecorder
|
||||
}
|
||||
|
||||
// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler
|
||||
type MockTLSExtensionHandlerMockRecorder struct {
|
||||
mock *MockTLSExtensionHandler
|
||||
}
|
||||
|
||||
// NewMockTLSExtensionHandler creates a new mock instance
|
||||
func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler {
|
||||
mock := &MockTLSExtensionHandler{ctrl: ctrl}
|
||||
mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetPeerParams mocks base method
|
||||
func (m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters {
|
||||
ret := m.ctrl.Call(m, "GetPeerParams")
|
||||
ret0, _ := ret[0].(<-chan handshake.TransportParameters)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetPeerParams indicates an expected call of GetPeerParams
|
||||
func (mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams))
|
||||
}
|
||||
|
||||
// Receive mocks base method
|
||||
func (m *MockTLSExtensionHandler) Receive(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error {
|
||||
ret := m.ctrl.Call(m, "Receive", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Receive indicates an expected call of Receive
|
||||
func (mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1)
|
||||
}
|
||||
|
||||
// Send mocks base method
|
||||
func (m *MockTLSExtensionHandler) Send(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error {
|
||||
ret := m.ctrl.Call(m, "Send", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Send indicates an expected call of Send
|
||||
func (mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1)
|
||||
}
|
24
vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go
generated
vendored
|
@ -7,22 +7,22 @@ type EncryptionLevel int
|
|||
const (
|
||||
// EncryptionUnspecified is a not specified encryption level
|
||||
EncryptionUnspecified EncryptionLevel = iota
|
||||
// EncryptionUnencrypted is not encrypted
|
||||
EncryptionUnencrypted
|
||||
// EncryptionSecure is encrypted, but not forward secure
|
||||
EncryptionSecure
|
||||
// EncryptionForwardSecure is forward secure
|
||||
EncryptionForwardSecure
|
||||
// EncryptionInitial is the Initial encryption level
|
||||
EncryptionInitial
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake
|
||||
// Encryption1RTT is the 1-RTT encryption level
|
||||
Encryption1RTT
|
||||
)
|
||||
|
||||
func (e EncryptionLevel) String() string {
|
||||
switch e {
|
||||
case EncryptionUnencrypted:
|
||||
return "unencrypted"
|
||||
case EncryptionSecure:
|
||||
return "encrypted (not forward-secure)"
|
||||
case EncryptionForwardSecure:
|
||||
return "forward-secure"
|
||||
case EncryptionInitial:
|
||||
return "Initial"
|
||||
case EncryptionHandshake:
|
||||
return "Handshake"
|
||||
case Encryption1RTT:
|
||||
return "1-RTT"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
|
|
@ -8,17 +8,13 @@ func InferPacketNumber(
|
|||
version VersionNumber,
|
||||
) PacketNumber {
|
||||
var epochDelta PacketNumber
|
||||
if version.UsesVarintPacketNumbers() {
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 7
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 14
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 30
|
||||
}
|
||||
} else {
|
||||
epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8)
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 7
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 14
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 30
|
||||
}
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
prevEpochBegin := epoch - epochDelta
|
||||
|
@ -48,8 +44,7 @@ func delta(a, b PacketNumber) PacketNumber {
|
|||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
|
||||
diff := uint64(packetNumber - leastUnacked)
|
||||
if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) ||
|
||||
!version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) {
|
||||
if diff < (1 << (14 - 1)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
return PacketNumberLen4
|
||||
|
@ -63,8 +58,5 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
|
|||
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
if packetNumber < (1 << (uint8(PacketNumberLen4) * 8)) {
|
||||
return PacketNumberLen4
|
||||
}
|
||||
return PacketNumberLen6
|
||||
return PacketNumberLen4
|
||||
}
|
||||
|
|
|
@ -8,6 +8,9 @@ const MaxPacketSizeIPv4 = 1252
|
|||
// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
|
||||
const MaxPacketSizeIPv6 = 1232
|
||||
|
||||
// MinStatelessResetSize is the minimum size of a stateless reset packet
|
||||
const MinStatelessResetSize = 1 + 20 + 16
|
||||
|
||||
// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
|
||||
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
|
||||
const NonForwardSecurePacketSizeReduction = 50
|
||||
|
@ -24,38 +27,22 @@ const InitialCongestionWindow ByteCount = 32 * DefaultTCPMSS
|
|||
// session queues for later until it sends a public reset.
|
||||
const MaxUndecryptablePackets = 10
|
||||
|
||||
// PublicResetTimeout is the time to wait before sending a Public Reset when receiving too many undecryptable packets during the handshake
|
||||
// This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto
|
||||
const PublicResetTimeout = 500 * time.Millisecond
|
||||
|
||||
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
|
||||
// This is the value that Google servers are using
|
||||
const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB
|
||||
|
||||
// ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data
|
||||
// This is the value that Google servers are using
|
||||
const ReceiveConnectionFlowControlWindow = (1 << 10) * 48 // 48 kB
|
||||
|
||||
// DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server
|
||||
// This is the value that Google servers are using
|
||||
const DefaultMaxReceiveStreamFlowControlWindowServer = 1 * (1 << 20) // 1 MB
|
||||
|
||||
// DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server
|
||||
// This is the value that Google servers are using
|
||||
const DefaultMaxReceiveConnectionFlowControlWindowServer = 1.5 * (1 << 20) // 1.5 MB
|
||||
|
||||
// DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client
|
||||
// This is the value that Chromium is using
|
||||
const DefaultMaxReceiveStreamFlowControlWindowClient = 6 * (1 << 20) // 6 MB
|
||||
|
||||
// DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client
|
||||
// This is the value that Google servers are using
|
||||
const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15 MB
|
||||
|
||||
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
|
||||
// This is the value that Chromium is using
|
||||
const ConnectionFlowControlMultiplier = 1.5
|
||||
|
||||
// InitialMaxStreamData is the stream-level flow control window for receiving data
|
||||
const InitialMaxStreamData = (1 << 10) * 512 // 512 kb
|
||||
|
||||
// InitialMaxData is the connection-level flow control window for receiving data
|
||||
const InitialMaxData = ConnectionFlowControlMultiplier * InitialMaxStreamData
|
||||
|
||||
// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data, for the server
|
||||
const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB
|
||||
|
||||
// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data, for the server
|
||||
const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 12 MB
|
||||
|
||||
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
|
||||
const WindowUpdateThreshold = 0.25
|
||||
|
||||
|
@ -65,12 +52,6 @@ const DefaultMaxIncomingStreams = 100
|
|||
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
|
||||
const DefaultMaxIncomingUniStreams = 100
|
||||
|
||||
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
|
||||
const MaxStreamsMultiplier = 1.1
|
||||
|
||||
// MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used.
|
||||
const MaxStreamsMinimumIncrement = 10
|
||||
|
||||
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
|
||||
const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets
|
||||
|
||||
|
@ -103,15 +84,9 @@ const MaxNonRetransmittableAcks = 19
|
|||
// prevents DoS attacks against the streamFrameSorter
|
||||
const MaxStreamFrameSorterGaps = 1000
|
||||
|
||||
// CryptoMaxParams is the upper limit for the number of parameters in a crypto message.
|
||||
// Value taken from Chrome.
|
||||
const CryptoMaxParams = 128
|
||||
|
||||
// CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message.
|
||||
const CryptoParameterMaxLength = 4000
|
||||
|
||||
// EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX.
|
||||
const EphermalKeyLifetime = time.Minute
|
||||
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
|
||||
// This limits the size of the ClientHello and Certificates that can be received.
|
||||
const MaxCryptoStreamOffset = 16 * (1 << 10)
|
||||
|
||||
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
|
||||
const MinRemoteIdleTimeout = 5 * time.Second
|
||||
|
@ -122,12 +97,9 @@ const DefaultIdleTimeout = 30 * time.Second
|
|||
// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
|
||||
const DefaultHandshakeTimeout = 10 * time.Second
|
||||
|
||||
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
|
||||
// RetiredConnectionIDDeleteTimeout is the time we keep closed sessions around in order to retransmit the CONNECTION_CLOSE.
|
||||
// after this time all information about the old connection will be deleted
|
||||
const ClosedSessionDeleteTimeout = time.Minute
|
||||
|
||||
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
|
||||
const NumCachedCertificates = 128
|
||||
const RetiredConnectionIDDeleteTimeout = 5 * time.Second
|
||||
|
||||
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
|
||||
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
|
||||
|
@ -135,7 +107,7 @@ const NumCachedCertificates = 128
|
|||
// 2. it reduces the head-of-line blocking, when a packet is lost
|
||||
const MinStreamFrameSize ByteCount = 128
|
||||
|
||||
// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write
|
||||
// MaxAckFrameSize is the maximum size for an ACK frame that we write
|
||||
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
||||
// The MaxAckFrameSize should be large enough to encode many ACK range,
|
||||
// but must ensure that a maximum size ACK frame fits into one packet.
|
||||
|
@ -149,6 +121,3 @@ const MinPacingDelay time.Duration = 100 * time.Microsecond
|
|||
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
|
||||
// if no other value is configured.
|
||||
const DefaultConnectionIDLength = 4
|
||||
|
||||
// MaxRetries is the maximum number of Retries a client will do before failing the connection.
|
||||
const MaxRetries = 3
|
|
@ -19,11 +19,9 @@ const (
|
|||
PacketNumberLen2 PacketNumberLen = 2
|
||||
// PacketNumberLen4 is a packet number length of 4 bytes
|
||||
PacketNumberLen4 PacketNumberLen = 4
|
||||
// PacketNumberLen6 is a packet number length of 6 bytes
|
||||
PacketNumberLen6 PacketNumberLen = 6
|
||||
)
|
||||
|
||||
// The PacketType is the Long Header Type (only used for the IETF draft header format)
|
||||
// The PacketType is the Long Header Type
|
||||
type PacketType uint8
|
||||
|
||||
const (
|
||||
|
@ -71,10 +69,7 @@ const MaxReceivePacketSize ByteCount = 1452 - 64
|
|||
// Used in QUIC for congestion window computations in bytes.
|
||||
const DefaultTCPMSS ByteCount = 1460
|
||||
|
||||
// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC)
|
||||
const MinClientHelloSize = 1024
|
||||
|
||||
// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is required to have.
|
||||
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
|
||||
const MinInitialPacketSize = 1200
|
||||
|
||||
// MaxClientHellos is the maximum number of times we'll send a client hello
|
||||
|
@ -83,8 +78,5 @@ const MinInitialPacketSize = 1200
|
|||
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
|
||||
const MaxClientHellos = 3
|
||||
|
||||
// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets.
|
||||
const ConnectionIDLenGQUIC = 8
|
||||
|
||||
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
||||
const MinConnectionIDLenInitial = 8
|
||||
|
|
|
@ -3,34 +3,65 @@ package protocol
|
|||
// A StreamID in QUIC
|
||||
type StreamID uint64
|
||||
|
||||
// MaxBidiStreamID is the highest stream ID that the peer is allowed to open,
|
||||
// when it is allowed to open numStreams bidirectional streams.
|
||||
// It is only valid for IETF QUIC.
|
||||
func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
|
||||
// StreamType encodes if this is a unidirectional or bidirectional stream
|
||||
type StreamType uint8
|
||||
|
||||
const (
|
||||
// StreamTypeUni is a unidirectional stream
|
||||
StreamTypeUni StreamType = iota
|
||||
// StreamTypeBidi is a bidirectional stream
|
||||
StreamTypeBidi
|
||||
)
|
||||
|
||||
// InitiatedBy says if the stream was initiated by the client or by the server
|
||||
func (s StreamID) InitiatedBy() Perspective {
|
||||
if s%2 == 0 {
|
||||
return PerspectiveClient
|
||||
}
|
||||
return PerspectiveServer
|
||||
}
|
||||
|
||||
//Type says if this is a unidirectional or bidirectional stream
|
||||
func (s StreamID) Type() StreamType {
|
||||
if s%4 >= 2 {
|
||||
return StreamTypeUni
|
||||
}
|
||||
return StreamTypeBidi
|
||||
}
|
||||
|
||||
// StreamNum returns how many streams in total are below this
|
||||
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
|
||||
func (s StreamID) StreamNum() uint64 {
|
||||
return uint64(s/4) + 1
|
||||
}
|
||||
|
||||
// MaxStreamID is the highest stream ID that a peer is allowed to open,
|
||||
// when it is allowed to open numStreams.
|
||||
func MaxStreamID(stype StreamType, numStreams uint64, pers Perspective) StreamID {
|
||||
if numStreams == 0 {
|
||||
return 0
|
||||
}
|
||||
var first StreamID
|
||||
if pers == PerspectiveClient {
|
||||
first = 1
|
||||
} else {
|
||||
first = 4
|
||||
switch stype {
|
||||
case StreamTypeBidi:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 0
|
||||
case PerspectiveServer:
|
||||
first = 1
|
||||
}
|
||||
case StreamTypeUni:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 2
|
||||
case PerspectiveServer:
|
||||
first = 3
|
||||
}
|
||||
}
|
||||
return first + 4*StreamID(numStreams-1)
|
||||
}
|
||||
|
||||
// MaxUniStreamID is the highest stream ID that the peer is allowed to open,
|
||||
// when it is allowed to open numStreams unidirectional streams.
|
||||
// It is only valid for IETF QUIC.
|
||||
func MaxUniStreamID(numStreams int, pers Perspective) StreamID {
|
||||
if numStreams == 0 {
|
||||
return 0
|
||||
}
|
||||
var first StreamID
|
||||
if pers == PerspectiveClient {
|
||||
first = 3
|
||||
} else {
|
||||
first = 2
|
||||
}
|
||||
return first + 4*StreamID(numStreams-1)
|
||||
// FirstStream returns the first valid stream ID
|
||||
func FirstStream(stype StreamType, pers Perspective) StreamID {
|
||||
return MaxStreamID(stype, 1, pers)
|
||||
}
|
||||
|
|
|
@ -18,32 +18,18 @@ const (
|
|||
|
||||
// The version numbers, making grepping easier
|
||||
const (
|
||||
Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9
|
||||
Version43 VersionNumber = gquicVersion0 + 4*0x100 + 0x3
|
||||
Version44 VersionNumber = gquicVersion0 + 4*0x100 + 0x4
|
||||
VersionTLS VersionNumber = 101
|
||||
VersionWhatever VersionNumber = 0 // for when the version doesn't matter
|
||||
VersionWhatever VersionNumber = 1 // for when the version doesn't matter
|
||||
VersionUnknown VersionNumber = math.MaxUint32
|
||||
|
||||
VersionMilestone0_10_0 VersionNumber = 0x51474f02
|
||||
)
|
||||
|
||||
// SupportedVersions lists the versions that the server supports
|
||||
// must be in sorted descending order
|
||||
var SupportedVersions = []VersionNumber{
|
||||
Version44,
|
||||
Version43,
|
||||
Version39,
|
||||
}
|
||||
var SupportedVersions = []VersionNumber{VersionTLS}
|
||||
|
||||
// IsValidVersion says if the version is known to quic-go
|
||||
func IsValidVersion(v VersionNumber) bool {
|
||||
return v == VersionTLS || v == VersionMilestone0_10_0 || IsSupportedVersion(SupportedVersions, v)
|
||||
}
|
||||
|
||||
// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake
|
||||
func (vn VersionNumber) UsesTLS() bool {
|
||||
return !vn.isGQUIC()
|
||||
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
|
||||
}
|
||||
|
||||
func (vn VersionNumber) String() string {
|
||||
|
@ -52,8 +38,6 @@ func (vn VersionNumber) String() string {
|
|||
return "whatever"
|
||||
case VersionUnknown:
|
||||
return "unknown"
|
||||
case VersionMilestone0_10_0:
|
||||
return "quic-go Milestone 0.10.0"
|
||||
case VersionTLS:
|
||||
return "TLS dev version (WIP)"
|
||||
default:
|
||||
|
@ -66,61 +50,9 @@ func (vn VersionNumber) String() string {
|
|||
|
||||
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
|
||||
func (vn VersionNumber) ToAltSvc() string {
|
||||
if vn.isGQUIC() {
|
||||
return fmt.Sprintf("%d", vn.toGQUICVersion())
|
||||
}
|
||||
return fmt.Sprintf("%d", vn)
|
||||
}
|
||||
|
||||
// CryptoStreamID gets the Stream ID of the crypto stream
|
||||
func (vn VersionNumber) CryptoStreamID() StreamID {
|
||||
if vn.isGQUIC() {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// UsesIETFFrameFormat tells if this version uses the IETF frame format
|
||||
func (vn VersionNumber) UsesIETFFrameFormat() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// UsesIETFHeaderFormat tells if this version uses the IETF header format
|
||||
func (vn VersionNumber) UsesIETFHeaderFormat() bool {
|
||||
return !vn.isGQUIC() || vn >= Version44
|
||||
}
|
||||
|
||||
// UsesLengthInHeader tells if this version uses the Length field in the IETF header
|
||||
func (vn VersionNumber) UsesLengthInHeader() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// UsesTokenInHeader tells if this version uses the Token field in the IETF header
|
||||
func (vn VersionNumber) UsesTokenInHeader() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// UsesStopWaitingFrames tells if this version uses STOP_WAITING frames
|
||||
func (vn VersionNumber) UsesStopWaitingFrames() bool {
|
||||
return vn.isGQUIC() && vn <= Version43
|
||||
}
|
||||
|
||||
// UsesVarintPacketNumbers tells if this version uses 7/14/30 bit packet numbers
|
||||
func (vn VersionNumber) UsesVarintPacketNumbers() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
|
||||
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
|
||||
if id == vn.CryptoStreamID() {
|
||||
return false
|
||||
}
|
||||
if vn.isGQUIC() && id == 3 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (vn VersionNumber) isGQUIC() bool {
|
||||
return vn > gquicVersion0 && vn <= maxGquicVersion
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
)
|
||||
|
||||
// ErrorCode can be used as a normal error without reason.
|
||||
type ErrorCode uint32
|
||||
type ErrorCode uint16
|
||||
|
||||
func (e ErrorCode) Error() string {
|
||||
return e.String()
|
|
@ -13,13 +13,6 @@ type ByteOrder interface {
|
|||
ReadUint16(io.ByteReader) (uint16, error)
|
||||
|
||||
WriteUint64(*bytes.Buffer, uint64)
|
||||
WriteUint56(*bytes.Buffer, uint64)
|
||||
WriteUint48(*bytes.Buffer, uint64)
|
||||
WriteUint40(*bytes.Buffer, uint64)
|
||||
WriteUint32(*bytes.Buffer, uint32)
|
||||
WriteUint24(*bytes.Buffer, uint32)
|
||||
WriteUint16(*bytes.Buffer, uint16)
|
||||
|
||||
ReadUfloat16(io.ByteReader) (uint64, error)
|
||||
WriteUfloat16(*bytes.Buffer, uint64)
|
||||
}
|
||||
|
|
50
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
50
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
|
@ -2,7 +2,6 @@ package utils
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
|
@ -97,61 +96,12 @@ func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
|
|||
})
|
||||
}
|
||||
|
||||
// WriteUint56 writes 56 bit of a uint64
|
||||
func (bigEndian) WriteUint56(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 56) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
|
||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint48 writes 48 bit of a uint64
|
||||
func (bigEndian) WriteUint48(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 48) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i >> 40), uint8(i >> 32),
|
||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint40 writes 40 bit of a uint64
|
||||
func (bigEndian) WriteUint40(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 40) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i >> 32),
|
||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint32 writes a uint32
|
||||
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
||||
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
// WriteUint24 writes 24 bit of a uint32
|
||||
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
|
||||
if i >= (1 << 24) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
|
||||
}
|
||||
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
// WriteUint16 writes a uint16
|
||||
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
||||
b.Write([]byte{uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
func (l bigEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
|
||||
return readUfloat16(b, l)
|
||||
}
|
||||
|
||||
func (l bigEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
|
||||
writeUfloat16(b, l, val)
|
||||
}
|
||||
|
|
157
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go
generated
vendored
157
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go
generated
vendored
|
@ -1,157 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// LittleEndian is the little-endian implementation of ByteOrder.
|
||||
var LittleEndian ByteOrder = littleEndian{}
|
||||
|
||||
type littleEndian struct{}
|
||||
|
||||
var _ ByteOrder = &littleEndian{}
|
||||
|
||||
// ReadUintN reads N bytes
|
||||
func (littleEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
|
||||
var res uint64
|
||||
for i := uint8(0); i < length; i++ {
|
||||
bt, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res ^= uint64(bt) << (i * 8)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// ReadUint64 reads a uint64
|
||||
func (littleEndian) ReadUint64(b io.ByteReader) (uint64, error) {
|
||||
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
|
||||
var err error
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b4, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b5, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b6, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b7, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b8, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
|
||||
}
|
||||
|
||||
// ReadUint32 reads a uint32
|
||||
func (littleEndian) ReadUint32(b io.ByteReader) (uint32, error) {
|
||||
var b1, b2, b3, b4 uint8
|
||||
var err error
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b4, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
|
||||
}
|
||||
|
||||
// ReadUint16 reads a uint16
|
||||
func (littleEndian) ReadUint16(b io.ByteReader) (uint16, error) {
|
||||
var b1, b2 uint8
|
||||
var err error
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(b1) + uint16(b2)<<8, nil
|
||||
}
|
||||
|
||||
// WriteUint64 writes a uint64
|
||||
func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) {
|
||||
b.Write([]byte{
|
||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
||||
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), uint8(i >> 56),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint56 writes 56 bit of a uint64
|
||||
func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 56) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
||||
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint48 writes 48 bit of a uint64
|
||||
func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 48) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
||||
uint8(i >> 32), uint8(i >> 40),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint40 writes 40 bit of a uint64
|
||||
func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 40) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i), uint8(i >> 8), uint8(i >> 16),
|
||||
uint8(i >> 24), uint8(i >> 32),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint32 writes a uint32
|
||||
func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
||||
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24)})
|
||||
}
|
||||
|
||||
// WriteUint24 writes 24 bit of a uint32
|
||||
func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) {
|
||||
if i >= (1 << 24) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
|
||||
}
|
||||
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)})
|
||||
}
|
||||
|
||||
// WriteUint16 writes a uint16
|
||||
func (littleEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
||||
b.Write([]byte{uint8(i), uint8(i >> 8)})
|
||||
}
|
||||
|
||||
func (l littleEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
|
||||
return readUfloat16(b, l)
|
||||
}
|
||||
|
||||
func (l littleEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
|
||||
writeUfloat16(b, l, val)
|
||||
}
|
|
@ -1,86 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
)
|
||||
|
||||
// We define an unsigned 16-bit floating point value, inspired by IEEE floats
|
||||
// (http://en.wikipedia.org/wiki/Half_precision_floating-point_format),
|
||||
// with 5-bit exponent (bias 1), 11-bit mantissa (effective 12 with hidden
|
||||
// bit) and denormals, but without signs, transfinites or fractions. Wire format
|
||||
// 16 bits (little-endian byte order) are split into exponent (high 5) and
|
||||
// mantissa (low 11) and decoded as:
|
||||
// uint64_t value;
|
||||
// if (exponent == 0) value = mantissa;
|
||||
// else value = (mantissa | 1 << 11) << (exponent - 1)
|
||||
const uFloat16ExponentBits = 5
|
||||
const uFloat16MaxExponent = (1 << uFloat16ExponentBits) - 2 // 30
|
||||
const uFloat16MantissaBits = 16 - uFloat16ExponentBits // 11
|
||||
const uFloat16MantissaEffectiveBits = uFloat16MantissaBits + 1 // 12
|
||||
const uFloat16MaxValue = ((uint64(1) << uFloat16MantissaEffectiveBits) - 1) << uFloat16MaxExponent // 0x3FFC0000000
|
||||
|
||||
// readUfloat16 reads a float in the QUIC-float16 format and returns its uint64 representation
|
||||
func readUfloat16(b io.ByteReader, byteOrder ByteOrder) (uint64, error) {
|
||||
val, err := byteOrder.ReadUint16(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
res := uint64(val)
|
||||
|
||||
if res < (1 << uFloat16MantissaEffectiveBits) {
|
||||
// Fast path: either the value is denormalized (no hidden bit), or
|
||||
// normalized (hidden bit set, exponent offset by one) with exponent zero.
|
||||
// Zero exponent offset by one sets the bit exactly where the hidden bit is.
|
||||
// So in both cases the value encodes itself.
|
||||
return res, nil
|
||||
}
|
||||
|
||||
exponent := val >> uFloat16MantissaBits // No sign extend on uint!
|
||||
// After the fast pass, the exponent is at least one (offset by one).
|
||||
// Un-offset the exponent.
|
||||
exponent--
|
||||
// Here we need to clear the exponent and set the hidden bit. We have already
|
||||
// decremented the exponent, so when we subtract it, it leaves behind the
|
||||
// hidden bit.
|
||||
res -= uint64(exponent) << uFloat16MantissaBits
|
||||
res <<= exponent
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// writeUfloat16 writes a float in the QUIC-float16 format from its uint64 representation
|
||||
func writeUfloat16(b *bytes.Buffer, byteOrder ByteOrder, value uint64) {
|
||||
var result uint16
|
||||
if value < (uint64(1) << uFloat16MantissaEffectiveBits) {
|
||||
// Fast path: either the value is denormalized, or has exponent zero.
|
||||
// Both cases are represented by the value itself.
|
||||
result = uint16(value)
|
||||
} else if value >= uFloat16MaxValue {
|
||||
// Value is out of range; clamp it to the maximum representable.
|
||||
result = math.MaxUint16
|
||||
} else {
|
||||
// The highest bit is between position 13 and 42 (zero-based), which
|
||||
// corresponds to exponent 1-30. In the output, mantissa is from 0 to 10,
|
||||
// hidden bit is 11 and exponent is 11 to 15. Shift the highest bit to 11
|
||||
// and count the shifts.
|
||||
exponent := uint16(0)
|
||||
for offset := uint16(16); offset > 0; offset /= 2 {
|
||||
// Right-shift the value until the highest bit is in position 11.
|
||||
// For offset of 16, 8, 4, 2 and 1 (binary search over 1-30),
|
||||
// shift if the bit is at or above 11 + offset.
|
||||
if value >= (uint64(1) << (uFloat16MantissaBits + offset)) {
|
||||
exponent += offset
|
||||
value >>= offset
|
||||
}
|
||||
}
|
||||
|
||||
// Hidden bit (position 11) is set. We should remove it and increment the
|
||||
// exponent. Equivalently, we just add it to the exponent.
|
||||
// This hides the bit.
|
||||
result = (uint16(value) + (exponent << uFloat16MantissaBits))
|
||||
}
|
||||
|
||||
byteOrder.WriteUint16(b, result)
|
||||
}
|
|
@ -13,29 +13,21 @@ import (
|
|||
// TODO: use the value sent in the transport parameters
|
||||
const ackDelayExponent = 3
|
||||
|
||||
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
|
||||
|
||||
// An AckFrame is an ACK frame
|
||||
type AckFrame struct {
|
||||
AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
|
||||
DelayTime time.Duration
|
||||
}
|
||||
|
||||
func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
|
||||
return parseAckOrAckEcnFrame(r, false, version)
|
||||
}
|
||||
|
||||
func parseAckEcnFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
|
||||
return parseAckOrAckEcnFrame(r, true, version)
|
||||
}
|
||||
|
||||
// parseAckFrame reads an ACK frame
|
||||
func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNumber) (*AckFrame, error) {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return parseAckFrameLegacy(r, version)
|
||||
}
|
||||
|
||||
if _, err := r.ReadByte(); err != nil {
|
||||
func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
|
||||
typeByte, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ecn := typeByte&0x1 > 0
|
||||
|
||||
frame := &AckFrame{}
|
||||
|
||||
|
@ -50,14 +42,6 @@ func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNu
|
|||
}
|
||||
frame.DelayTime = time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
|
||||
|
||||
if ecn {
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := utils.ReadVarInt(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
numBlocks, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -103,16 +87,22 @@ func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNu
|
|||
if !frame.validateAckRanges() {
|
||||
return nil, errInvalidAckRanges
|
||||
}
|
||||
|
||||
// parse (and skip) the ECN section
|
||||
if ecn {
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := utils.ReadVarInt(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
// Write writes an ACK frame.
|
||||
func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return f.writeLegacy(b, version)
|
||||
}
|
||||
|
||||
b.WriteByte(0x0d)
|
||||
b.WriteByte(0x2)
|
||||
utils.WriteVarInt(b, uint64(f.LargestAcked()))
|
||||
utils.WriteVarInt(b, encodeAckDelay(f.DelayTime))
|
||||
|
||||
|
@ -134,10 +124,6 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
|||
|
||||
// Length of a written frame
|
||||
func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return f.lengthLegacy(version)
|
||||
}
|
||||
|
||||
largestAcked := f.AckRanges[0].Largest
|
||||
numRanges := f.numEncodableAckRanges()
|
||||
|
||||
|
|
|
@ -1,364 +0,0 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
|
||||
|
||||
func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, error) {
|
||||
frame := &AckFrame{}
|
||||
|
||||
typeByte, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hasMissingRanges := typeByte&0x20 == 0x20
|
||||
largestAckedLen := 2 * ((typeByte & 0x0C) >> 2)
|
||||
if largestAckedLen == 0 {
|
||||
largestAckedLen = 1
|
||||
}
|
||||
|
||||
missingSequenceNumberDeltaLen := 2 * (typeByte & 0x03)
|
||||
if missingSequenceNumberDeltaLen == 0 {
|
||||
missingSequenceNumberDeltaLen = 1
|
||||
}
|
||||
|
||||
la, err := utils.BigEndian.ReadUintN(r, largestAckedLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
largestAcked := protocol.PacketNumber(la)
|
||||
|
||||
delay, err := utils.BigEndian.ReadUfloat16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.DelayTime = time.Duration(delay) * time.Microsecond
|
||||
|
||||
var numAckBlocks uint8
|
||||
if hasMissingRanges {
|
||||
numAckBlocks, err = r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if hasMissingRanges && numAckBlocks == 0 {
|
||||
return nil, errInvalidAckRanges
|
||||
}
|
||||
|
||||
abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ackBlockLength := protocol.PacketNumber(abl)
|
||||
if largestAcked > 0 && ackBlockLength < 1 {
|
||||
return nil, errors.New("invalid first ACK range")
|
||||
}
|
||||
|
||||
if ackBlockLength > largestAcked+1 {
|
||||
return nil, errInvalidAckRanges
|
||||
}
|
||||
|
||||
if hasMissingRanges {
|
||||
ackRange := AckRange{
|
||||
Smallest: largestAcked - ackBlockLength + 1,
|
||||
Largest: largestAcked,
|
||||
}
|
||||
frame.AckRanges = append(frame.AckRanges, ackRange)
|
||||
|
||||
var inLongBlock bool
|
||||
var lastRangeComplete bool
|
||||
for i := uint8(0); i < numAckBlocks; i++ {
|
||||
var gap uint8
|
||||
gap, err = r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ackBlockLength := protocol.PacketNumber(abl)
|
||||
|
||||
if inLongBlock {
|
||||
frame.AckRanges[len(frame.AckRanges)-1].Smallest -= protocol.PacketNumber(gap) + ackBlockLength
|
||||
frame.AckRanges[len(frame.AckRanges)-1].Largest -= protocol.PacketNumber(gap)
|
||||
} else {
|
||||
lastRangeComplete = false
|
||||
ackRange := AckRange{
|
||||
Largest: frame.AckRanges[len(frame.AckRanges)-1].Smallest - protocol.PacketNumber(gap) - 1,
|
||||
}
|
||||
ackRange.Smallest = ackRange.Largest - ackBlockLength + 1
|
||||
frame.AckRanges = append(frame.AckRanges, ackRange)
|
||||
}
|
||||
|
||||
if ackBlockLength > 0 {
|
||||
lastRangeComplete = true
|
||||
}
|
||||
inLongBlock = (ackBlockLength == 0)
|
||||
}
|
||||
|
||||
// if the last range was not complete, First and Last make no sense
|
||||
// remove the range from frame.AckRanges
|
||||
if !lastRangeComplete {
|
||||
frame.AckRanges = frame.AckRanges[:len(frame.AckRanges)-1]
|
||||
}
|
||||
} else {
|
||||
frame.AckRanges = make([]AckRange, 1)
|
||||
if largestAcked != 0 {
|
||||
frame.AckRanges[0].Largest = largestAcked
|
||||
frame.AckRanges[0].Smallest = largestAcked + 1 - ackBlockLength
|
||||
}
|
||||
}
|
||||
|
||||
if !frame.validateAckRanges() {
|
||||
return nil, errInvalidAckRanges
|
||||
}
|
||||
|
||||
var numTimestamp byte
|
||||
numTimestamp, err = r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if numTimestamp > 0 {
|
||||
// Delta Largest acked
|
||||
_, err = r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// First Timestamp
|
||||
_, err = utils.BigEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i < int(numTimestamp)-1; i++ {
|
||||
// Delta Largest acked
|
||||
_, err = r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Time Since Previous Timestamp
|
||||
_, err = utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||
largestAcked := f.LargestAcked()
|
||||
largestAckedLen := protocol.GetPacketNumberLength(largestAcked)
|
||||
|
||||
typeByte := uint8(0x40)
|
||||
|
||||
if largestAckedLen != protocol.PacketNumberLen1 {
|
||||
typeByte ^= (uint8(largestAckedLen / 2)) << 2
|
||||
}
|
||||
|
||||
missingSequenceNumberDeltaLen := f.getMissingSequenceNumberDeltaLen()
|
||||
if missingSequenceNumberDeltaLen != protocol.PacketNumberLen1 {
|
||||
typeByte ^= (uint8(missingSequenceNumberDeltaLen / 2))
|
||||
}
|
||||
|
||||
if f.HasMissingRanges() {
|
||||
typeByte |= 0x20
|
||||
}
|
||||
|
||||
b.WriteByte(typeByte)
|
||||
|
||||
switch largestAckedLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(largestAcked))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.BigEndian.WriteUint16(b, uint16(largestAcked))
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.BigEndian.WriteUint32(b, uint32(largestAcked))
|
||||
case protocol.PacketNumberLen6:
|
||||
utils.BigEndian.WriteUint48(b, uint64(largestAcked)&(1<<48-1))
|
||||
}
|
||||
|
||||
utils.BigEndian.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond))
|
||||
|
||||
var numRanges uint64
|
||||
var numRangesWritten uint64
|
||||
if f.HasMissingRanges() {
|
||||
numRanges = f.numWritableNackRanges()
|
||||
if numRanges > 0xFF {
|
||||
panic("AckFrame: Too many ACK ranges")
|
||||
}
|
||||
b.WriteByte(uint8(numRanges - 1))
|
||||
}
|
||||
|
||||
var firstAckBlockLength protocol.PacketNumber
|
||||
if !f.HasMissingRanges() {
|
||||
firstAckBlockLength = largestAcked - f.LowestAcked() + 1
|
||||
} else {
|
||||
firstAckBlockLength = largestAcked - f.AckRanges[0].Smallest + 1
|
||||
numRangesWritten++
|
||||
}
|
||||
|
||||
switch missingSequenceNumberDeltaLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(firstAckBlockLength))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.BigEndian.WriteUint16(b, uint16(firstAckBlockLength))
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.BigEndian.WriteUint32(b, uint32(firstAckBlockLength))
|
||||
case protocol.PacketNumberLen6:
|
||||
utils.BigEndian.WriteUint48(b, uint64(firstAckBlockLength)&(1<<48-1))
|
||||
}
|
||||
|
||||
for i, ackRange := range f.AckRanges {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
length := ackRange.Largest - ackRange.Smallest + 1
|
||||
gap := f.AckRanges[i-1].Smallest - ackRange.Largest - 1
|
||||
|
||||
num := gap/0xFF + 1
|
||||
if gap%0xFF == 0 {
|
||||
num--
|
||||
}
|
||||
|
||||
if num == 1 {
|
||||
b.WriteByte(uint8(gap))
|
||||
switch missingSequenceNumberDeltaLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(length))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.BigEndian.WriteUint16(b, uint16(length))
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.BigEndian.WriteUint32(b, uint32(length))
|
||||
case protocol.PacketNumberLen6:
|
||||
utils.BigEndian.WriteUint48(b, uint64(length)&(1<<48-1))
|
||||
}
|
||||
numRangesWritten++
|
||||
} else {
|
||||
for i := 0; i < int(num); i++ {
|
||||
var lengthWritten uint64
|
||||
var gapWritten uint8
|
||||
|
||||
if i == int(num)-1 { // last block
|
||||
lengthWritten = uint64(length)
|
||||
gapWritten = uint8(1 + ((gap - 1) % 255))
|
||||
} else {
|
||||
lengthWritten = 0
|
||||
gapWritten = 0xFF
|
||||
}
|
||||
|
||||
b.WriteByte(gapWritten)
|
||||
switch missingSequenceNumberDeltaLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(lengthWritten))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.BigEndian.WriteUint16(b, uint16(lengthWritten))
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.BigEndian.WriteUint32(b, uint32(lengthWritten))
|
||||
case protocol.PacketNumberLen6:
|
||||
utils.BigEndian.WriteUint48(b, lengthWritten&(1<<48-1))
|
||||
}
|
||||
|
||||
numRangesWritten++
|
||||
}
|
||||
}
|
||||
|
||||
// this is needed if not all AckRanges can be written to the ACK frame (if there are more than 0xFF)
|
||||
if numRangesWritten >= numRanges {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if numRanges != numRangesWritten {
|
||||
return errors.New("BUG: Inconsistent number of ACK ranges written")
|
||||
}
|
||||
|
||||
b.WriteByte(0) // no timestamps
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *AckFrame) lengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp
|
||||
length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked()))
|
||||
|
||||
missingSequenceNumberDeltaLen := protocol.ByteCount(f.getMissingSequenceNumberDeltaLen())
|
||||
|
||||
if f.HasMissingRanges() {
|
||||
length += (1 + missingSequenceNumberDeltaLen) * protocol.ByteCount(f.numWritableNackRanges())
|
||||
} else {
|
||||
length += missingSequenceNumberDeltaLen
|
||||
}
|
||||
// we don't write
|
||||
return length
|
||||
}
|
||||
|
||||
// numWritableNackRanges calculates the number of ACK blocks that are about to be written
|
||||
// this number is different from len(f.AckRanges) for the case of long gaps (> 255 packets)
|
||||
func (f *AckFrame) numWritableNackRanges() uint64 {
|
||||
if len(f.AckRanges) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var numRanges uint64
|
||||
for i, ackRange := range f.AckRanges {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
lastAckRange := f.AckRanges[i-1]
|
||||
gap := lastAckRange.Smallest - ackRange.Largest - 1
|
||||
rangeLength := 1 + uint64(gap)/0xFF
|
||||
if uint64(gap)%0xFF == 0 {
|
||||
rangeLength--
|
||||
}
|
||||
|
||||
if numRanges+rangeLength < 0xFF {
|
||||
numRanges += rangeLength
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return numRanges + 1
|
||||
}
|
||||
|
||||
func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen {
|
||||
var maxRangeLength protocol.PacketNumber
|
||||
|
||||
if f.HasMissingRanges() {
|
||||
for _, ackRange := range f.AckRanges {
|
||||
rangeLength := ackRange.Largest - ackRange.Smallest + 1
|
||||
if rangeLength > maxRangeLength {
|
||||
maxRangeLength = rangeLength
|
||||
}
|
||||
}
|
||||
} else {
|
||||
maxRangeLength = f.LargestAcked() - f.LowestAcked() + 1
|
||||
}
|
||||
|
||||
if maxRangeLength <= 0xFF {
|
||||
return protocol.PacketNumberLen1
|
||||
}
|
||||
if maxRangeLength <= 0xFFFF {
|
||||
return protocol.PacketNumberLen2
|
||||
}
|
||||
if maxRangeLength <= 0xFFFFFFFF {
|
||||
return protocol.PacketNumberLen4
|
||||
}
|
||||
|
||||
return protocol.PacketNumberLen6
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A BlockedFrame is a BLOCKED frame
|
||||
type BlockedFrame struct {
|
||||
Offset protocol.ByteCount
|
||||
}
|
||||
|
||||
// parseBlockedFrame parses a BLOCKED frame
|
||||
func parseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) {
|
||||
if _, err := r.ReadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &BlockedFrame{
|
||||
Offset: protocol.ByteCount(offset),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return (&blockedFrameLegacy{}).Write(b, version)
|
||||
}
|
||||
typeByte := uint8(0x08)
|
||||
b.WriteByte(typeByte)
|
||||
utils.WriteVarInt(b, uint64(f.Offset))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *BlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if !version.UsesIETFFrameFormat() {
|
||||
return 1 + 4
|
||||
}
|
||||
return 1 + utils.VarIntLen(uint64(f.Offset))
|
||||
}
|
37
vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go
generated
vendored
37
vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go
generated
vendored
|
@ -1,37 +0,0 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type blockedFrameLegacy struct {
|
||||
StreamID protocol.StreamID
|
||||
}
|
||||
|
||||
// parseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format)
|
||||
// The frame returned is
|
||||
// * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream
|
||||
// * a BLOCKED frame, if the BLOCKED applies to the connection
|
||||
func parseBlockedFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) {
|
||||
if _, err := r.ReadByte(); err != nil { // read the TypeByte
|
||||
return nil, err
|
||||
}
|
||||
streamID, err := utils.BigEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if streamID == 0 {
|
||||
return &BlockedFrame{}, nil
|
||||
}
|
||||
return &StreamBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil
|
||||
}
|
||||
|
||||
//Write writes a BLOCKED frame
|
||||
func (f *blockedFrameLegacy) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||
b.WriteByte(0x05)
|
||||
utils.BigEndian.WriteUint32(b, uint32(f.StreamID))
|
||||
return nil
|
||||
}
|
91
vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go
generated
vendored
91
vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go
generated
vendored
|
@ -2,52 +2,43 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// A ConnectionCloseFrame in QUIC
|
||||
// A ConnectionCloseFrame is a CONNECTION_CLOSE frame
|
||||
type ConnectionCloseFrame struct {
|
||||
ErrorCode qerr.ErrorCode
|
||||
ReasonPhrase string
|
||||
IsApplicationError bool
|
||||
ErrorCode qerr.ErrorCode
|
||||
ReasonPhrase string
|
||||
}
|
||||
|
||||
// parseConnectionCloseFrame reads a CONNECTION_CLOSE frame
|
||||
func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) {
|
||||
if _, err := r.ReadByte(); err != nil { // read the TypeByte
|
||||
typeByte, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var errorCode qerr.ErrorCode
|
||||
var reasonPhraseLen uint64
|
||||
if version.UsesIETFFrameFormat() {
|
||||
ec, err := utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errorCode = qerr.ErrorCode(ec)
|
||||
reasonPhraseLen, err = utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
ec, err := utils.BigEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errorCode = qerr.ErrorCode(ec)
|
||||
length, err := utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reasonPhraseLen = uint64(length)
|
||||
f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d}
|
||||
ec, err := utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.ErrorCode = qerr.ErrorCode(ec)
|
||||
// read the Frame Type, if this is not an application error
|
||||
if !f.IsApplicationError {
|
||||
if _, err := utils.ReadVarInt(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var reasonPhraseLen uint64
|
||||
reasonPhraseLen, err = utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// shortcut to prevent the unnecessary allocation of dataLen bytes
|
||||
// if the dataLen is larger than the remaining length of the packet
|
||||
// reading the whole reason phrase would result in EOF when attempting to READ
|
||||
|
@ -60,37 +51,31 @@ func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber)
|
|||
// this should never happen, since we already checked the reasonPhraseLen earlier
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ConnectionCloseFrame{
|
||||
ErrorCode: errorCode,
|
||||
ReasonPhrase: string(reasonPhrase),
|
||||
}, nil
|
||||
f.ReasonPhrase = string(reasonPhrase)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *ConnectionCloseFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
if version.UsesIETFFrameFormat() {
|
||||
return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
|
||||
length := 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
|
||||
if !f.IsApplicationError {
|
||||
length++ // for the frame type
|
||||
}
|
||||
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase))
|
||||
return length
|
||||
}
|
||||
|
||||
// Write writes an CONNECTION_CLOSE frame.
|
||||
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
b.WriteByte(0x02)
|
||||
|
||||
if len(f.ReasonPhrase) > math.MaxUint16 {
|
||||
return errors.New("ConnectionFrame: ReasonPhrase too long")
|
||||
}
|
||||
|
||||
if version.UsesIETFFrameFormat() {
|
||||
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
|
||||
utils.WriteVarInt(b, uint64(len(f.ReasonPhrase)))
|
||||
if f.IsApplicationError {
|
||||
b.WriteByte(0x1d)
|
||||
} else {
|
||||
utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
|
||||
utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase)))
|
||||
b.WriteByte(0x1c)
|
||||
}
|
||||
b.WriteString(f.ReasonPhrase)
|
||||
|
||||
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
|
||||
if !f.IsApplicationError {
|
||||
utils.WriteVarInt(b, 0)
|
||||
}
|
||||
utils.WriteVarInt(b, uint64(len(f.ReasonPhrase)))
|
||||
b.WriteString(f.ReasonPhrase)
|
||||
return nil
|
||||
}
|
||||
|
|
71
vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go
generated
vendored
Normal file
71
vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go
generated
vendored
Normal file
|
@ -0,0 +1,71 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A CryptoFrame is a CRYPTO frame
|
||||
type CryptoFrame struct {
|
||||
Offset protocol.ByteCount
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) {
|
||||
if _, err := r.ReadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
frame := &CryptoFrame{}
|
||||
offset, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.Offset = protocol.ByteCount(offset)
|
||||
dataLen, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dataLen > uint64(r.Len()) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
if dataLen != 0 {
|
||||
frame.Data = make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(r, frame.Data); err != nil {
|
||||
// this should never happen, since we already checked the dataLen earlier
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||
b.WriteByte(0x6)
|
||||
utils.WriteVarInt(b, uint64(f.Offset))
|
||||
utils.WriteVarInt(b, uint64(len(f.Data)))
|
||||
b.Write(f.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
return 1 + utils.VarIntLen(uint64(f.Offset)) + utils.VarIntLen(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
|
||||
}
|
||||
|
||||
// MaxDataLen returns the maximum data length
|
||||
func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount {
|
||||
// pretend that the data size will be 1 bytes
|
||||
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
|
||||
headerLen := 1 + utils.VarIntLen(uint64(f.Offset)) + 1
|
||||
if headerLen > maxSize {
|
||||
return 0
|
||||
}
|
||||
maxDataLen := maxSize - headerLen
|
||||
if utils.VarIntLen(uint64(maxDataLen)) != 1 {
|
||||
maxDataLen--
|
||||
}
|
||||
return maxDataLen
|
||||
}
|
38
vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go
generated
vendored
Normal file
38
vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go
generated
vendored
Normal file
|
@ -0,0 +1,38 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A DataBlockedFrame is a DATA_BLOCKED frame
|
||||
type DataBlockedFrame struct {
|
||||
DataLimit protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) {
|
||||
if _, err := r.ReadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
offset, err := utils.ReadVarInt(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DataBlockedFrame{
|
||||
DataLimit: protocol.ByteCount(offset),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||
typeByte := uint8(0x14)
|
||||
b.WriteByte(typeByte)
|
||||
utils.WriteVarInt(b, uint64(f.DataLimit))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
return 1 + utils.VarIntLen(uint64(f.DataLimit))
|
||||
}
|
|
@ -2,16 +2,15 @@ package wire
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
)
|
||||
|
||||
// ParseNextFrame parses the next frame
|
||||
// It skips PADDING frames.
|
||||
func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Frame, error) {
|
||||
func ParseNextFrame(r *bytes.Reader, v protocol.VersionNumber) (Frame, error) {
|
||||
for r.Len() != 0 {
|
||||
typeByte, _ := r.ReadByte()
|
||||
if typeByte == 0x0 { // PADDING frame
|
||||
|
@ -19,154 +18,61 @@ func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Fra
|
|||
}
|
||||
r.UnreadByte()
|
||||
|
||||
if !v.UsesIETFFrameFormat() {
|
||||
return parseGQUICFrame(r, typeByte, hdr, v)
|
||||
}
|
||||
return parseIETFFrame(r, typeByte, v)
|
||||
return parseFrame(r, typeByte, v)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func parseIETFFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) {
|
||||
func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) {
|
||||
var frame Frame
|
||||
var err error
|
||||
if typeByte&0xf8 == 0x10 {
|
||||
if typeByte&0xf8 == 0x8 {
|
||||
frame, err = parseStreamFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidStreamData, err.Error())
|
||||
return nil, qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
return frame, err
|
||||
return frame, nil
|
||||
}
|
||||
// TODO: implement all IETF QUIC frame types
|
||||
switch typeByte {
|
||||
case 0x1:
|
||||
frame, err = parseRstStreamFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
|
||||
}
|
||||
case 0x2:
|
||||
frame, err = parseConnectionCloseFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
|
||||
}
|
||||
case 0x4:
|
||||
frame, err = parseMaxDataFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
||||
}
|
||||
case 0x5:
|
||||
frame, err = parseMaxStreamDataFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
||||
}
|
||||
case 0x6:
|
||||
frame, err = parseMaxStreamIDFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
case 0x7:
|
||||
frame, err = parsePingFrame(r, v)
|
||||
case 0x8:
|
||||
frame, err = parseBlockedFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
|
||||
}
|
||||
case 0x9:
|
||||
frame, err = parseStreamBlockedFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
|
||||
}
|
||||
case 0xa:
|
||||
frame, err = parseStreamIDBlockedFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
case 0xc:
|
||||
case 0x2, 0x3:
|
||||
frame, err = parseAckFrame(r, v)
|
||||
case 0x4:
|
||||
frame, err = parseResetStreamFrame(r, v)
|
||||
case 0x5:
|
||||
frame, err = parseStopSendingFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
case 0xd:
|
||||
frame, err = parseAckFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidAckData, err.Error())
|
||||
}
|
||||
case 0xe:
|
||||
frame, err = parsePathChallengeFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
case 0xf:
|
||||
frame, err = parsePathResponseFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
case 0x1a:
|
||||
frame, err = parseAckEcnFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidAckData, err.Error())
|
||||
}
|
||||
default:
|
||||
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
|
||||
}
|
||||
return frame, err
|
||||
}
|
||||
|
||||
func parseGQUICFrame(r *bytes.Reader, typeByte byte, hdr *Header, v protocol.VersionNumber) (Frame, error) {
|
||||
var frame Frame
|
||||
var err error
|
||||
if typeByte&0x80 == 0x80 {
|
||||
frame, err = parseStreamFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidStreamData, err.Error())
|
||||
}
|
||||
return frame, err
|
||||
} else if typeByte&0xc0 == 0x40 {
|
||||
frame, err = parseAckFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidAckData, err.Error())
|
||||
}
|
||||
return frame, err
|
||||
}
|
||||
switch typeByte {
|
||||
case 0x1:
|
||||
frame, err = parseRstStreamFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
|
||||
}
|
||||
case 0x2:
|
||||
frame, err = parseConnectionCloseFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
|
||||
}
|
||||
case 0x3:
|
||||
frame, err = parseGoawayFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidGoawayData, err.Error())
|
||||
}
|
||||
case 0x4:
|
||||
frame, err = parseWindowUpdateFrame(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
||||
}
|
||||
case 0x5:
|
||||
frame, err = parseBlockedFrameLegacy(r, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
|
||||
}
|
||||
case 0x6:
|
||||
if !v.UsesStopWaitingFrames() {
|
||||
err = errors.New("STOP_WAITING frames not supported by this QUIC version")
|
||||
break
|
||||
}
|
||||
frame, err = parseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, v)
|
||||
if err != nil {
|
||||
err = qerr.Error(qerr.InvalidStopWaitingData, err.Error())
|
||||
}
|
||||
frame, err = parseCryptoFrame(r, v)
|
||||
case 0x7:
|
||||
frame, err = parsePingFrame(r, v)
|
||||
frame, err = parseNewTokenFrame(r, v)
|
||||
case 0x10:
|
||||
frame, err = parseMaxDataFrame(r, v)
|
||||
case 0x11:
|
||||
frame, err = parseMaxStreamDataFrame(r, v)
|
||||
case 0x12, 0x13:
|
||||
frame, err = parseMaxStreamsFrame(r, v)
|
||||
case 0x14:
|
||||
frame, err = parseDataBlockedFrame(r, v)
|
||||
case 0x15:
|
||||
frame, err = parseStreamDataBlockedFrame(r, v)
|
||||
case 0x16, 0x17:
|
||||
frame, err = parseStreamsBlockedFrame(r, v)
|
||||
case 0x18:
|
||||
frame, err = parseNewConnectionIDFrame(r, v)
|
||||
case 0x19:
|
||||
frame, err = parseRetireConnectionIDFrame(r, v)
|
||||
case 0x1a:
|
||||
frame, err = parsePathChallengeFrame(r, v)
|
||||
case 0x1b:
|
||||
frame, err = parsePathResponseFrame(r, v)
|
||||
case 0x1c, 0x1d:
|
||||
frame, err = parseConnectionCloseFrame(r, v)
|
||||
default:
|
||||
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
|
||||
err = fmt.Errorf("unknown type byte 0x%x", typeByte)
|
||||
}
|
||||
return frame, err
|
||||
if err != nil {
|
||||
return nil, qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
|
|
@ -1,68 +0,0 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// A GoawayFrame is a GOAWAY frame
|
||||
type GoawayFrame struct {
|
||||
ErrorCode qerr.ErrorCode
|
||||
LastGoodStream protocol.StreamID
|
||||
ReasonPhrase string
|
||||
}
|
||||
|
||||
// parseGoawayFrame parses a GOAWAY frame
|
||||
func parseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, error) {
|
||||
frame := &GoawayFrame{}
|
||||
|
||||
if _, err := r.ReadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errorCode, err := utils.BigEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ErrorCode = qerr.ErrorCode(errorCode)
|
||||
|
||||
lastGoodStream, err := utils.BigEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.LastGoodStream = protocol.StreamID(lastGoodStream)
|
||||
|
||||
reasonPhraseLen, err := utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if reasonPhraseLen > uint16(protocol.MaxReceivePacketSize) {
|
||||
return nil, qerr.Error(qerr.InvalidGoawayData, "reason phrase too long")
|
||||
}
|
||||
|
||||
reasonPhrase := make([]byte, reasonPhraseLen)
|
||||
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame.ReasonPhrase = string(reasonPhrase)
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||
b.WriteByte(0x03)
|
||||
utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
|
||||
utils.BigEndian.WriteUint32(b, uint32(f.LastGoodStream))
|
||||
utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase)))
|
||||
b.WriteString(f.ReasonPhrase)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *GoawayFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||
return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase))
|
||||
}
|
|
@ -3,7 +3,6 @@ package wire
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -11,10 +10,7 @@ import (
|
|||
)
|
||||
|
||||
// Header is the header of a QUIC packet.
|
||||
// It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header.
|
||||
type Header struct {
|
||||
IsPublicHeader bool
|
||||
|
||||
Raw []byte
|
||||
|
||||
Version protocol.VersionNumber
|
||||
|
@ -29,12 +25,6 @@ type Header struct {
|
|||
IsVersionNegotiation bool
|
||||
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
|
||||
|
||||
// only needed for the gQUIC Public Header
|
||||
VersionFlag bool
|
||||
ResetFlag bool
|
||||
DiversificationNonce []byte
|
||||
|
||||
// only needed for the IETF Header
|
||||
Type protocol.PacketType
|
||||
IsLongHeader bool
|
||||
KeyPhase int
|
||||
|
@ -42,15 +32,8 @@ type Header struct {
|
|||
Token []byte
|
||||
}
|
||||
|
||||
var errInvalidPacketNumberLen = errors.New("invalid packet number length")
|
||||
|
||||
// Write writes the Header.
|
||||
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
|
||||
if !ver.UsesIETFHeaderFormat() {
|
||||
h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
|
||||
return h.writePublicHeader(b, pers, ver)
|
||||
}
|
||||
// write an IETF QUIC header
|
||||
if h.IsLongHeader {
|
||||
return h.writeLongHeader(b, ver)
|
||||
}
|
||||
|
@ -69,7 +52,7 @@ func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) erro
|
|||
b.Write(h.DestConnectionID.Bytes())
|
||||
b.Write(h.SrcConnectionID.Bytes())
|
||||
|
||||
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
utils.WriteVarInt(b, uint64(len(h.Token)))
|
||||
b.Write(h.Token)
|
||||
}
|
||||
|
@ -89,176 +72,36 @@ func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) erro
|
|||
return nil
|
||||
}
|
||||
|
||||
if v.UsesLengthInHeader() {
|
||||
utils.WriteVarInt(b, uint64(h.PayloadLen))
|
||||
}
|
||||
if v.UsesVarintPacketNumbers() {
|
||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
}
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
||||
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
|
||||
if len(h.DiversificationNonce) != 32 {
|
||||
return errors.New("invalid diversification nonce length")
|
||||
}
|
||||
b.Write(h.DiversificationNonce)
|
||||
}
|
||||
return nil
|
||||
utils.WriteVarInt(b, uint64(h.PayloadLen))
|
||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
}
|
||||
|
||||
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||
typeByte := byte(0x30)
|
||||
typeByte |= byte(h.KeyPhase << 6)
|
||||
if !v.UsesVarintPacketNumbers() {
|
||||
switch h.PacketNumberLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
case protocol.PacketNumberLen2:
|
||||
typeByte |= 0x1
|
||||
case protocol.PacketNumberLen4:
|
||||
typeByte |= 0x2
|
||||
default:
|
||||
return errInvalidPacketNumberLen
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteByte(typeByte)
|
||||
b.Write(h.DestConnectionID.Bytes())
|
||||
|
||||
if !v.UsesVarintPacketNumbers() {
|
||||
switch h.PacketNumberLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(h.PacketNumber))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||
}
|
||||
|
||||
// writePublicHeader writes a Public Header.
|
||||
func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
|
||||
if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
|
||||
return errors.New("PublicHeader: Can only write regular packets")
|
||||
}
|
||||
if h.SrcConnectionID.Len() != 0 {
|
||||
return errors.New("PublicHeader: SrcConnectionID must not be set")
|
||||
}
|
||||
if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
|
||||
return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
|
||||
}
|
||||
|
||||
publicFlagByte := uint8(0x00)
|
||||
if h.VersionFlag {
|
||||
publicFlagByte |= 0x01
|
||||
}
|
||||
if h.DestConnectionID.Len() > 0 {
|
||||
publicFlagByte |= 0x08
|
||||
}
|
||||
if len(h.DiversificationNonce) > 0 {
|
||||
if len(h.DiversificationNonce) != 32 {
|
||||
return errors.New("invalid diversification nonce length")
|
||||
}
|
||||
publicFlagByte |= 0x04
|
||||
}
|
||||
switch h.PacketNumberLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
publicFlagByte |= 0x00
|
||||
case protocol.PacketNumberLen2:
|
||||
publicFlagByte |= 0x10
|
||||
case protocol.PacketNumberLen4:
|
||||
publicFlagByte |= 0x20
|
||||
}
|
||||
b.WriteByte(publicFlagByte)
|
||||
|
||||
if h.DestConnectionID.Len() > 0 {
|
||||
b.Write(h.DestConnectionID)
|
||||
}
|
||||
if h.VersionFlag && pers == protocol.PerspectiveClient {
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
||||
}
|
||||
if len(h.DiversificationNonce) > 0 {
|
||||
b.Write(h.DiversificationNonce)
|
||||
}
|
||||
|
||||
switch h.PacketNumberLen {
|
||||
case protocol.PacketNumberLen1:
|
||||
b.WriteByte(uint8(h.PacketNumber))
|
||||
case protocol.PacketNumberLen2:
|
||||
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
|
||||
case protocol.PacketNumberLen4:
|
||||
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
||||
case protocol.PacketNumberLen6:
|
||||
return errInvalidPacketNumberLen
|
||||
default:
|
||||
return errors.New("PublicHeader: PacketNumberLen not set")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLength determines the length of the Header.
|
||||
func (h *Header) GetLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
if !v.UsesIETFHeaderFormat() {
|
||||
return h.getPublicHeaderLength()
|
||||
}
|
||||
return h.getHeaderLength(v)
|
||||
}
|
||||
|
||||
func (h *Header) getHeaderLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
|
||||
func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
|
||||
if h.IsLongHeader {
|
||||
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
|
||||
if v.UsesLengthInHeader() {
|
||||
length += utils.VarIntLen(uint64(h.PayloadLen))
|
||||
}
|
||||
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
|
||||
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.PayloadLen))
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
||||
}
|
||||
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
|
||||
length += protocol.ByteCount(len(h.DiversificationNonce))
|
||||
}
|
||||
return length, nil
|
||||
return length
|
||||
}
|
||||
|
||||
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
||||
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
|
||||
return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
|
||||
}
|
||||
length += protocol.ByteCount(h.PacketNumberLen)
|
||||
return length, nil
|
||||
}
|
||||
|
||||
// getPublicHeaderLength gets the length of the publicHeader in bytes.
|
||||
// It can only be called for regular packets.
|
||||
func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
|
||||
length := protocol.ByteCount(1) // 1 byte for public flags
|
||||
if h.PacketNumberLen == protocol.PacketNumberLen6 {
|
||||
return 0, errInvalidPacketNumberLen
|
||||
}
|
||||
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
|
||||
return 0, errPacketNumberLenNotSet
|
||||
}
|
||||
length += protocol.ByteCount(h.PacketNumberLen)
|
||||
length += protocol.ByteCount(h.DestConnectionID.Len())
|
||||
// Version Number in packets sent by the client
|
||||
if h.VersionFlag {
|
||||
length += 4
|
||||
}
|
||||
length += protocol.ByteCount(len(h.DiversificationNonce))
|
||||
return length, nil
|
||||
return length
|
||||
}
|
||||
|
||||
// Log logs the Header
|
||||
func (h *Header) Log(logger utils.Logger) {
|
||||
if h.IsPublicHeader {
|
||||
h.logPublicHeader(logger)
|
||||
} else {
|
||||
h.logHeader(logger)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Header) logHeader(logger utils.Logger) {
|
||||
if h.IsLongHeader {
|
||||
if h.Version == 0 {
|
||||
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
|
||||
|
@ -275,14 +118,6 @@ func (h *Header) logHeader(logger utils.Logger) {
|
|||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
|
||||
return
|
||||
}
|
||||
if h.Version == protocol.Version44 {
|
||||
var divNonce string
|
||||
if h.Type == protocol.PacketType0RTT {
|
||||
divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce)
|
||||
}
|
||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version)
|
||||
return
|
||||
}
|
||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
|
||||
}
|
||||
} else {
|
||||
|
@ -290,14 +125,6 @@ func (h *Header) logHeader(logger utils.Logger) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Header) logPublicHeader(logger utils.Logger) {
|
||||
ver := "(unset)"
|
||||
if h.Version != 0 {
|
||||
ver = h.Version.String()
|
||||
}
|
||||
logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
|
||||
}
|
||||
|
||||
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
||||
dcil, err := encodeSingleConnIDLen(dest)
|
||||
if err != nil {
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// The InvariantHeader is the version independent part of the header
|
||||
|
@ -32,22 +32,10 @@ func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Invariant
|
|||
|
||||
// If this is not a Long Header, it could either be a Public Header or a Short Header.
|
||||
if !h.IsLongHeader {
|
||||
// In the Public Header 0x8 is the Connection ID Flag.
|
||||
// In the IETF Short Header:
|
||||
// * 0x8 it is the gQUIC Demultiplexing bit, and always 0.
|
||||
// * 0x20 and 0x10 are always 1.
|
||||
var connIDLen int
|
||||
if typeByte&0x8 > 0 { // Public Header containing a connection ID
|
||||
connIDLen = 8
|
||||
}
|
||||
if typeByte&0x38 == 0x30 { // Short Header
|
||||
connIDLen = shortHeaderConnIDLen
|
||||
}
|
||||
if connIDLen > 0 {
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, connIDLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var err error
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
@ -81,15 +69,6 @@ func (iv *InvariantHeader) Parse(b *bytes.Reader, sentBy protocol.Perspective, v
|
|||
}
|
||||
return iv.parseLongHeader(b, sentBy, ver)
|
||||
}
|
||||
// The Public Header never uses 6 byte packet numbers.
|
||||
// Therefore, the third and fourth bit will never be 11.
|
||||
// For the Short Header, the third and fourth bit are always 11.
|
||||
if iv.typeByte&0x30 != 0x30 {
|
||||
if sentBy == protocol.PerspectiveServer && iv.typeByte&0x1 > 0 {
|
||||
return iv.parseVersionNegotiationPacket(b)
|
||||
}
|
||||
return iv.parsePublicHeader(b, sentBy, ver)
|
||||
}
|
||||
return iv.parseShortHeader(b, ver)
|
||||
}
|
||||
|
||||
|
@ -104,7 +83,6 @@ func (iv *InvariantHeader) toHeader() *Header {
|
|||
|
||||
func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Header, error) {
|
||||
h := iv.toHeader()
|
||||
h.VersionFlag = true
|
||||
if b.Len() == 0 {
|
||||
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
||||
}
|
||||
|
@ -145,7 +123,7 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Pers
|
|||
return h, nil
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
tokenLen, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -159,37 +137,17 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Pers
|
|||
}
|
||||
}
|
||||
|
||||
if v.UsesLengthInHeader() {
|
||||
pl, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PayloadLen = protocol.ByteCount(pl)
|
||||
pl, err := utils.ReadVarInt(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v.UsesVarintPacketNumbers() {
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = pn
|
||||
h.PacketNumberLen = pnLen
|
||||
} else {
|
||||
pn, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = protocol.PacketNumber(pn)
|
||||
h.PacketNumberLen = protocol.PacketNumberLen4
|
||||
}
|
||||
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 && sentBy == protocol.PerspectiveServer {
|
||||
h.DiversificationNonce = make([]byte, 32)
|
||||
if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
h.PayloadLen = protocol.ByteCount(pl)
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = pn
|
||||
h.PacketNumberLen = pnLen
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
@ -198,76 +156,12 @@ func (iv *InvariantHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionN
|
|||
h := iv.toHeader()
|
||||
h.KeyPhase = int(iv.typeByte&0x40) >> 6
|
||||
|
||||
if v.UsesVarintPacketNumbers() {
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = pn
|
||||
h.PacketNumberLen = pnLen
|
||||
} else {
|
||||
switch iv.typeByte & 0x3 {
|
||||
case 0x0:
|
||||
h.PacketNumberLen = protocol.PacketNumberLen1
|
||||
case 0x1:
|
||||
h.PacketNumberLen = protocol.PacketNumberLen2
|
||||
case 0x2:
|
||||
h.PacketNumberLen = protocol.PacketNumberLen4
|
||||
default:
|
||||
return nil, errInvalidPacketNumberLen
|
||||
}
|
||||
p, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = protocol.PacketNumber(p)
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (iv *InvariantHeader) parsePublicHeader(b *bytes.Reader, sentBy protocol.Perspective, ver protocol.VersionNumber) (*Header, error) {
|
||||
h := iv.toHeader()
|
||||
h.IsPublicHeader = true
|
||||
h.ResetFlag = iv.typeByte&0x2 > 0
|
||||
if h.ResetFlag {
|
||||
return h, nil
|
||||
}
|
||||
|
||||
h.VersionFlag = iv.typeByte&0x1 > 0
|
||||
if h.VersionFlag && sentBy == protocol.PerspectiveClient {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.Version = protocol.VersionNumber(v)
|
||||
}
|
||||
|
||||
// Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server.
|
||||
// It doesn't have any meaning when sent by the client.
|
||||
if sentBy == protocol.PerspectiveServer && iv.typeByte&0x4 > 0 {
|
||||
h.DiversificationNonce = make([]byte, 32)
|
||||
if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch iv.typeByte & 0x30 {
|
||||
case 0x00:
|
||||
h.PacketNumberLen = protocol.PacketNumberLen1
|
||||
case 0x10:
|
||||
h.PacketNumberLen = protocol.PacketNumberLen2
|
||||
case 0x20:
|
||||
h.PacketNumberLen = protocol.PacketNumberLen4
|
||||
}
|
||||
|
||||
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
|
||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.PacketNumber = protocol.PacketNumber(pn)
|
||||
h.PacketNumber = pn
|
||||
h.PacketNumberLen = pnLen
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -17,14 +18,11 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) {
|
|||
dir = "->"
|
||||
}
|
||||
switch f := frame.(type) {
|
||||
case *CryptoFrame:
|
||||
dataLen := protocol.ByteCount(len(f.Data))
|
||||
logger.Debugf("\t%s &wire.CryptoFrame{Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.Offset, dataLen, f.Offset+dataLen)
|
||||
case *StreamFrame:
|
||||
logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
|
||||
case *StopWaitingFrame:
|
||||
if sent {
|
||||
logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
|
||||
} else {
|
||||
logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
|
||||
}
|
||||
case *AckFrame:
|
||||
if len(f.AckRanges) > 1 {
|
||||
ackRanges := make([]string, len(f.AckRanges))
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue