mirror of https://github.com/v2ray/v2ray-core
parent
6870ead73e
commit
19926b8e4f
|
@ -5,37 +5,85 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"v2ray.com/core/transport/internet/tls"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
quic "github.com/lucas-clemente/quic-go"
|
||||||
|
|
||||||
"v2ray.com/core/common"
|
"v2ray.com/core/common"
|
||||||
"v2ray.com/core/common/net"
|
"v2ray.com/core/common/net"
|
||||||
"v2ray.com/core/transport/internet"
|
"v2ray.com/core/transport/internet"
|
||||||
|
"v2ray.com/core/transport/internet/tls"
|
||||||
)
|
)
|
||||||
|
|
||||||
type clientSessions struct {
|
type clientSessions struct {
|
||||||
access sync.Mutex
|
access sync.Mutex
|
||||||
sessions map[net.Destination]quic.Session
|
sessions map[net.Destination][]quic.Session
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (quic.Session, error) {
|
func removeInactiveSessions(sessions []quic.Session) []quic.Session {
|
||||||
|
lastActive := 0
|
||||||
|
for _, s := range sessions {
|
||||||
|
active := true
|
||||||
|
select {
|
||||||
|
case <-s.Context().Done():
|
||||||
|
active = false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if active {
|
||||||
|
sessions[lastActive] = s
|
||||||
|
lastActive++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastActive < len(sessions) {
|
||||||
|
for i := 0; i < len(sessions); i++ {
|
||||||
|
sessions[i] = nil
|
||||||
|
}
|
||||||
|
sessions = sessions[:lastActive]
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessions
|
||||||
|
}
|
||||||
|
|
||||||
|
func openStream(sessions []quic.Session) (quic.Stream, net.Addr, error) {
|
||||||
|
for _, s := range sessions {
|
||||||
|
stream, err := s.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
newError("failed to create stream").Base(err).WriteToLog()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return stream, s.LocalAddr(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
|
||||||
s.access.Lock()
|
s.access.Lock()
|
||||||
defer s.access.Unlock()
|
defer s.access.Unlock()
|
||||||
|
|
||||||
if s.sessions == nil {
|
if s.sessions == nil {
|
||||||
s.sessions = make(map[net.Destination]quic.Session)
|
s.sessions = make(map[net.Destination][]quic.Session)
|
||||||
}
|
}
|
||||||
|
|
||||||
dest := net.DestinationFromAddr(destAddr)
|
dest := net.DestinationFromAddr(destAddr)
|
||||||
|
|
||||||
if session, found := s.sessions[dest]; found {
|
var sessions []quic.Session
|
||||||
select {
|
if s, found := s.sessions[dest]; found {
|
||||||
case <-session.Context().Done():
|
sessions = s
|
||||||
// Session has been closed. Creating a new one.
|
|
||||||
default:
|
|
||||||
return session, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessions = removeInactiveSessions(sessions)
|
||||||
|
s.sessions[dest] = sessions
|
||||||
|
|
||||||
|
stream, local, err := openStream(sessions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if stream != nil {
|
||||||
|
return &interConn{
|
||||||
|
stream: stream,
|
||||||
|
local: local,
|
||||||
|
remote: destAddr,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
|
rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
|
||||||
|
@ -47,13 +95,12 @@ func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
quicConfig := &quic.Config{
|
quicConfig := &quic.Config{
|
||||||
Versions: []quic.VersionNumber{quic.VersionMilestone0_10_0},
|
|
||||||
ConnectionIDLength: 12,
|
ConnectionIDLength: 12,
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
HandshakeTimeout: time.Second * 4,
|
HandshakeTimeout: time.Second * 4,
|
||||||
IdleTimeout: time.Second * 300,
|
IdleTimeout: time.Second * 60,
|
||||||
MaxReceiveStreamFlowControlWindow: 128 * 1024,
|
MaxReceiveStreamFlowControlWindow: 256 * 1024,
|
||||||
MaxReceiveConnectionFlowControlWindow: 512 * 1024,
|
MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
|
||||||
MaxIncomingUniStreams: -1,
|
MaxIncomingUniStreams: -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,8 +116,16 @@ func (s *clientSessions) getSession(destAddr net.Addr, config *Config, tlsConfig
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sessions[dest] = session
|
s.sessions[dest] = append(sessions, session)
|
||||||
return session, nil
|
stream, err = session.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &interConn{
|
||||||
|
stream: stream,
|
||||||
|
local: session.LocalAddr(),
|
||||||
|
remote: destAddr,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var client clientSessions
|
var client clientSessions
|
||||||
|
@ -91,21 +146,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
|
||||||
|
|
||||||
config := streamSettings.ProtocolSettings.(*Config)
|
config := streamSettings.ProtocolSettings.(*Config)
|
||||||
|
|
||||||
session, err := client.getSession(destAddr, config, tlsConfig, streamSettings.SocketSettings)
|
return client.openConnection(destAddr, config, tlsConfig, streamSettings.SocketSettings)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := session.OpenStreamSync()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &interConn{
|
|
||||||
stream: conn,
|
|
||||||
local: session.LocalAddr(),
|
|
||||||
remote: destAddr,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
@ -84,14 +84,13 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
|
||||||
}
|
}
|
||||||
|
|
||||||
quicConfig := &quic.Config{
|
quicConfig := &quic.Config{
|
||||||
Versions: []quic.VersionNumber{quic.VersionMilestone0_10_0},
|
|
||||||
ConnectionIDLength: 12,
|
ConnectionIDLength: 12,
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
HandshakeTimeout: time.Second * 4,
|
HandshakeTimeout: time.Second * 4,
|
||||||
IdleTimeout: time.Second * 300,
|
IdleTimeout: time.Second * 60,
|
||||||
MaxReceiveStreamFlowControlWindow: 128 * 1024,
|
MaxReceiveStreamFlowControlWindow: 256 * 1024,
|
||||||
MaxReceiveConnectionFlowControlWindow: 512 * 1024,
|
MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
|
||||||
MaxIncomingStreams: 256,
|
MaxIncomingStreams: 64,
|
||||||
MaxIncomingUniStreams: -1,
|
MaxIncomingUniStreams: -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,12 +8,19 @@
|
||||||
[](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
|
[](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
|
||||||
[](https://codecov.io/gh/lucas-clemente/quic-go/)
|
[](https://codecov.io/gh/lucas-clemente/quic-go/)
|
||||||
|
|
||||||
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
|
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go. It roughly implements the [IETF QUIC draft](https://github.com/quicwg/base-drafts), although we don't fully support any of the draft versions at the moment.
|
||||||
|
|
||||||
## Roadmap
|
## Version compatibility
|
||||||
|
|
||||||
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
|
Since quic-go is under active development, there's no guarantee that two builds of different commits are interoperable. The QUIC version used in the *master* branch is just a placeholder, and should not be considered stable.
|
||||||
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
|
|
||||||
|
If you want to use quic-go as a library in other projects, please consider using a [tagged release](https://github.com/lucas-clemente/quic-go/releases). These releases expose [experimental QUIC versions](https://github.com/quicwg/base-drafts/wiki/QUIC-Versions), which are guaranteed to be stable.
|
||||||
|
|
||||||
|
## Google QUIC
|
||||||
|
|
||||||
|
quic-go used to support both the QUIC versions supported by Google Chrome and QUIC as deployed on Google's servers, as well as IETF QUIC. Due to the divergence of the two protocols, we decided to not support both versions any more.
|
||||||
|
|
||||||
|
The *master* branch **only** supports IETF QUIC. For Google QUIC support, please refer to the [gquic branch](https://github.com/lucas-clemente/quic-go/tree/gquic).
|
||||||
|
|
||||||
## Guides
|
## Guides
|
||||||
|
|
||||||
|
@ -27,31 +34,19 @@ Running tests:
|
||||||
|
|
||||||
go test ./...
|
go test ./...
|
||||||
|
|
||||||
### Running the example server
|
### HTTP mapping
|
||||||
|
|
||||||
go run example/main.go -www /var/www/
|
We're currently not implementing the HTTP mapping as described in the [QUIC over HTTP draft](https://quicwg.org/base-drafts/draft-ietf-quic-http.html). The HTTP mapping here is a leftover from Google QUIC.
|
||||||
|
|
||||||
Using the `quic_client` from chromium:
|
|
||||||
|
|
||||||
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
|
|
||||||
|
|
||||||
Using Chrome:
|
|
||||||
|
|
||||||
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
|
|
||||||
|
|
||||||
### QUIC without HTTP/2
|
### QUIC without HTTP/2
|
||||||
|
|
||||||
Take a look at [this echo example](example/echo/echo.go).
|
Take a look at [this echo example](example/echo/echo.go).
|
||||||
|
|
||||||
### Using the example client
|
|
||||||
|
|
||||||
go run example/client/main.go https://clemente.io
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### As a server
|
### As a server
|
||||||
|
|
||||||
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
|
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
||||||
|
|
|
@ -10,6 +10,9 @@ environment:
|
||||||
- GOARCH: 386
|
- GOARCH: 386
|
||||||
- GOARCH: amd64
|
- GOARCH: amd64
|
||||||
|
|
||||||
|
hosts:
|
||||||
|
quic.clemente.io: 127.0.0.1
|
||||||
|
|
||||||
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
|
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
|
||||||
|
|
||||||
install:
|
install:
|
||||||
|
@ -19,7 +22,6 @@ install:
|
||||||
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
|
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
|
||||||
- echo %PATH%
|
- echo %PATH%
|
||||||
- echo %GOPATH%
|
- echo %GOPATH%
|
||||||
- git submodule update --init --recursive
|
|
||||||
- go get github.com/onsi/ginkgo/ginkgo
|
- go get github.com/onsi/ginkgo/ginkgo
|
||||||
- go get github.com/onsi/gomega
|
- go get github.com/onsi/gomega
|
||||||
- go version
|
- go version
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -9,12 +8,11 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
|
@ -25,29 +23,25 @@ type client struct {
|
||||||
// If it is started with Dial, we take a packet conn as a parameter.
|
// If it is started with Dial, we take a packet conn as a parameter.
|
||||||
createdPacketConn bool
|
createdPacketConn bool
|
||||||
|
|
||||||
hostname string
|
|
||||||
|
|
||||||
packetHandlers packetHandlerManager
|
packetHandlers packetHandlerManager
|
||||||
|
|
||||||
token []byte
|
token []byte
|
||||||
numRetries int
|
|
||||||
|
|
||||||
versionNegotiated bool // has the server accepted our version
|
versionNegotiated bool // has the server accepted our version
|
||||||
receivedVersionNegotiationPacket bool
|
receivedVersionNegotiationPacket bool
|
||||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||||
|
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
mintConf *mint.Config
|
|
||||||
config *Config
|
config *Config
|
||||||
|
|
||||||
srcConnID protocol.ConnectionID
|
srcConnID protocol.ConnectionID
|
||||||
destConnID protocol.ConnectionID
|
destConnID protocol.ConnectionID
|
||||||
|
origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
|
||||||
|
|
||||||
initialVersion protocol.VersionNumber
|
initialVersion protocol.VersionNumber
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
|
||||||
handshakeChan chan struct{}
|
handshakeChan chan struct{}
|
||||||
closeCallback func(protocol.ConnectionID)
|
|
||||||
|
|
||||||
session quicSession
|
session quicSession
|
||||||
|
|
||||||
|
@ -128,18 +122,11 @@ func dialContext(
|
||||||
createdPacketConn bool,
|
createdPacketConn bool,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
config = populateClientConfig(config, createdPacketConn)
|
config = populateClientConfig(config, createdPacketConn)
|
||||||
if !createdPacketConn {
|
|
||||||
for _, v := range config.Versions {
|
|
||||||
if v == protocol.Version44 {
|
|
||||||
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
|
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
|
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, createdPacketConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -156,16 +143,14 @@ func newClient(
|
||||||
config *Config,
|
config *Config,
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
host string,
|
host string,
|
||||||
closeCallback func(protocol.ConnectionID),
|
|
||||||
createdPacketConn bool,
|
createdPacketConn bool,
|
||||||
) (*client, error) {
|
) (*client, error) {
|
||||||
var hostname string
|
if tlsConf == nil {
|
||||||
if tlsConf != nil {
|
tlsConf = &tls.Config{}
|
||||||
hostname = tlsConf.ServerName
|
|
||||||
}
|
}
|
||||||
if hostname == "" {
|
if tlsConf.ServerName == "" {
|
||||||
var err error
|
var err error
|
||||||
hostname, _, err = net.SplitHostPort(host)
|
tlsConf.ServerName, _, err = net.SplitHostPort(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -179,19 +164,13 @@ func newClient(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
onClose := func(protocol.ConnectionID) {}
|
|
||||||
if closeCallback != nil {
|
|
||||||
onClose = closeCallback
|
|
||||||
}
|
|
||||||
c := &client{
|
c := &client{
|
||||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||||
createdPacketConn: createdPacketConn,
|
createdPacketConn: createdPacketConn,
|
||||||
hostname: hostname,
|
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
config: config,
|
config: config,
|
||||||
version: config.Versions[0],
|
version: config.Versions[0],
|
||||||
handshakeChan: make(chan struct{}),
|
handshakeChan: make(chan struct{}),
|
||||||
closeCallback: onClose,
|
|
||||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||||
}
|
}
|
||||||
return c, c.generateConnectionIDs()
|
return c, c.generateConnectionIDs()
|
||||||
|
@ -219,11 +198,11 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||||
|
|
||||||
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
||||||
if maxReceiveStreamFlowControlWindow == 0 {
|
if maxReceiveStreamFlowControlWindow == 0 {
|
||||||
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
|
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
|
||||||
}
|
}
|
||||||
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
|
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
|
||||||
if maxReceiveConnectionFlowControlWindow == 0 {
|
if maxReceiveConnectionFlowControlWindow == 0 {
|
||||||
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
|
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
|
||||||
}
|
}
|
||||||
maxIncomingStreams := config.MaxIncomingStreams
|
maxIncomingStreams := config.MaxIncomingStreams
|
||||||
if maxIncomingStreams == 0 {
|
if maxIncomingStreams == 0 {
|
||||||
|
@ -241,17 +220,11 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||||
if connIDLen == 0 && !createdPacketConn {
|
if connIDLen == 0 && !createdPacketConn {
|
||||||
connIDLen = protocol.DefaultConnectionIDLength
|
connIDLen = protocol.DefaultConnectionIDLength
|
||||||
}
|
}
|
||||||
for _, v := range versions {
|
|
||||||
if v == protocol.Version44 {
|
|
||||||
connIDLen = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
Versions: versions,
|
Versions: versions,
|
||||||
HandshakeTimeout: handshakeTimeout,
|
HandshakeTimeout: handshakeTimeout,
|
||||||
IdleTimeout: idleTimeout,
|
IdleTimeout: idleTimeout,
|
||||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
|
||||||
ConnectionIDLength: connIDLen,
|
ConnectionIDLength: connIDLen,
|
||||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||||
|
@ -262,75 +235,26 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) generateConnectionIDs() error {
|
func (c *client) generateConnectionIDs() error {
|
||||||
connIDLen := protocol.ConnectionIDLenGQUIC
|
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
|
||||||
if c.version.UsesTLS() {
|
|
||||||
connIDLen = c.config.ConnectionIDLength
|
|
||||||
}
|
|
||||||
srcConnID, err := generateConnectionID(connIDLen)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
destConnID := srcConnID
|
destConnID, err := generateConnectionIDForInitial()
|
||||||
if c.version.UsesTLS() {
|
|
||||||
destConnID, err = generateConnectionIDForInitial()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
c.srcConnID = srcConnID
|
c.srcConnID = srcConnID
|
||||||
c.destConnID = destConnID
|
c.destConnID = destConnID
|
||||||
if c.version == protocol.Version44 {
|
|
||||||
c.srcConnID = nil
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) dial(ctx context.Context) error {
|
func (c *client) dial(ctx context.Context) error {
|
||||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||||
|
|
||||||
var err error
|
if err := c.createNewTLSSession(c.version); err != nil {
|
||||||
if c.version.UsesTLS() {
|
|
||||||
err = c.dialTLS(ctx)
|
|
||||||
} else {
|
|
||||||
err = c.dialGQUIC(ctx)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) dialGQUIC(ctx context.Context) error {
|
|
||||||
if err := c.createNewGQUICSession(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err := c.establishSecureConnection(ctx)
|
err := c.establishSecureConnection(ctx)
|
||||||
if err == errCloseSessionForNewVersion {
|
|
||||||
return c.dial(ctx)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) dialTLS(ctx context.Context) error {
|
|
||||||
params := &handshake.TransportParameters{
|
|
||||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
|
||||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
|
||||||
IdleTimeout: c.config.IdleTimeout,
|
|
||||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
|
||||||
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
|
|
||||||
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
|
|
||||||
DisableMigration: true,
|
|
||||||
}
|
|
||||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
|
|
||||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
mintConf.ExtensionHandler = extHandler
|
|
||||||
mintConf.ServerName = c.hostname
|
|
||||||
c.mintConf = mintConf
|
|
||||||
|
|
||||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = c.establishSecureConnection(ctx)
|
|
||||||
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
|
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
|
||||||
return c.dial(ctx)
|
return c.dial(ctx)
|
||||||
}
|
}
|
||||||
|
@ -340,9 +264,9 @@ func (c *client) dialTLS(ctx context.Context) error {
|
||||||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||||
// It returns:
|
// It returns:
|
||||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
|
||||||
// - any other error that might occur
|
// - any other error that might occur
|
||||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
// - when the connection is forward-secure
|
||||||
func (c *client) establishSecureConnection(ctx context.Context) error {
|
func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||||
errorChan := make(chan error, 1)
|
errorChan := make(chan error, 1)
|
||||||
|
|
||||||
|
@ -387,35 +311,14 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.version.UsesIETFHeaderFormat() {
|
|
||||||
connID := p.header.DestConnectionID
|
|
||||||
// reject packets with truncated connection id if we didn't request truncation
|
|
||||||
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
|
|
||||||
return errors.New("received packet with truncated connection ID, but didn't request truncation")
|
|
||||||
}
|
|
||||||
// reject packets with the wrong connection ID
|
|
||||||
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
|
|
||||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
|
|
||||||
}
|
|
||||||
if p.header.ResetFlag {
|
|
||||||
return c.handlePublicReset(p)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// reject packets with the wrong connection ID
|
// reject packets with the wrong connection ID
|
||||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if p.header.IsLongHeader {
|
if p.header.Type == protocol.PacketTypeRetry {
|
||||||
switch p.header.Type {
|
|
||||||
case protocol.PacketTypeRetry:
|
|
||||||
c.handleRetryPacket(p.header)
|
c.handleRetryPacket(p.header)
|
||||||
return nil
|
return nil
|
||||||
case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is the first packet we are receiving
|
// this is the first packet we are receiving
|
||||||
|
@ -428,22 +331,6 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) handlePublicReset(p *receivedPacket) error {
|
|
||||||
cr := c.conn.RemoteAddr()
|
|
||||||
// check if the remote address and the connection ID match
|
|
||||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
|
||||||
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
|
|
||||||
return errors.New("Received a spoofed Public Reset")
|
|
||||||
}
|
|
||||||
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
|
||||||
}
|
|
||||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
|
||||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
// ignore delayed / duplicated version negotiation packets
|
// ignore delayed / duplicated version negotiation packets
|
||||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||||
|
@ -483,77 +370,56 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
||||||
c.logger.Debugf("<- Received Retry")
|
c.logger.Debugf("<- Received Retry")
|
||||||
hdr.Log(c.logger)
|
hdr.Log(c.logger)
|
||||||
// A server that performs multiple retries must use a source connection ID of at least 8 bytes.
|
|
||||||
// Only a server that won't send additional Retries can use shorter connection IDs.
|
|
||||||
if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
|
||||||
c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
|
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
|
||||||
c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.numRetries++
|
if hdr.SrcConnectionID.Equal(c.destConnID) {
|
||||||
if c.numRetries > protocol.MaxRetries {
|
c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
|
||||||
c.session.destroy(qerr.CryptoTooManyRejects)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// If a token is already set, this means that we already received a Retry from the server.
|
||||||
|
// Ignore this Retry packet.
|
||||||
|
if len(c.token) > 0 {
|
||||||
|
c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.origDestConnID = c.destConnID
|
||||||
c.destConnID = hdr.SrcConnectionID
|
c.destConnID = hdr.SrcConnectionID
|
||||||
c.token = hdr.Token
|
c.token = hdr.Token
|
||||||
c.session.destroy(errCloseSessionForRetry)
|
c.session.destroy(errCloseSessionForRetry)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) createNewGQUICSession() error {
|
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
|
||||||
|
params := &handshake.TransportParameters{
|
||||||
|
InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
|
||||||
|
InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
|
||||||
|
InitialMaxStreamDataUni: protocol.InitialMaxStreamData,
|
||||||
|
InitialMaxData: protocol.InitialMaxData,
|
||||||
|
IdleTimeout: c.config.IdleTimeout,
|
||||||
|
MaxBidiStreams: uint64(c.config.MaxIncomingStreams),
|
||||||
|
MaxUniStreams: uint64(c.config.MaxIncomingUniStreams),
|
||||||
|
DisableMigration: true,
|
||||||
|
}
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
runner := &runner{
|
runner := &runner{
|
||||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||||
removeConnectionIDImpl: c.closeCallback,
|
retireConnectionIDImpl: c.packetHandlers.Retire,
|
||||||
|
removeConnectionIDImpl: c.packetHandlers.Remove,
|
||||||
}
|
}
|
||||||
sess, err := newClientSession(
|
sess, err := newClientSession(
|
||||||
c.conn,
|
|
||||||
runner,
|
|
||||||
c.hostname,
|
|
||||||
c.version,
|
|
||||||
c.destConnID,
|
|
||||||
c.srcConnID,
|
|
||||||
c.tlsConf,
|
|
||||||
c.config,
|
|
||||||
c.initialVersion,
|
|
||||||
c.negotiatedVersions,
|
|
||||||
c.logger,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.session = sess
|
|
||||||
c.packetHandlers.Add(c.srcConnID, c)
|
|
||||||
if c.config.RequestConnectionIDOmission {
|
|
||||||
c.packetHandlers.Add(protocol.ConnectionID{}, c)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) createNewTLSSession(
|
|
||||||
paramsChan <-chan handshake.TransportParameters,
|
|
||||||
version protocol.VersionNumber,
|
|
||||||
) error {
|
|
||||||
c.mutex.Lock()
|
|
||||||
defer c.mutex.Unlock()
|
|
||||||
runner := &runner{
|
|
||||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
|
||||||
removeConnectionIDImpl: c.closeCallback,
|
|
||||||
}
|
|
||||||
sess, err := newTLSClientSession(
|
|
||||||
c.conn,
|
c.conn,
|
||||||
runner,
|
runner,
|
||||||
c.token,
|
c.token,
|
||||||
|
c.origDestConnID,
|
||||||
c.destConnID,
|
c.destConnID,
|
||||||
c.srcConnID,
|
c.srcConnID,
|
||||||
c.config,
|
c.config,
|
||||||
c.mintConf,
|
c.tlsConf,
|
||||||
paramsChan,
|
params,
|
||||||
1,
|
c.initialVersion,
|
||||||
c.logger,
|
c.logger,
|
||||||
c.version,
|
c.version,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,41 +1,108 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
type cryptoStream interface {
|
type cryptoStream interface {
|
||||||
StreamID() protocol.StreamID
|
// for receiving data
|
||||||
io.Reader
|
HandleCryptoFrame(*wire.CryptoFrame) error
|
||||||
|
GetCryptoData() []byte
|
||||||
|
Finish() error
|
||||||
|
// for sending data
|
||||||
io.Writer
|
io.Writer
|
||||||
handleStreamFrame(*wire.StreamFrame) error
|
HasData() bool
|
||||||
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
|
||||||
closeForShutdown(error)
|
|
||||||
setReadOffset(protocol.ByteCount)
|
|
||||||
// methods needed for flow control
|
|
||||||
getWindowUpdate() protocol.ByteCount
|
|
||||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type cryptoStreamImpl struct {
|
type cryptoStreamImpl struct {
|
||||||
*stream
|
queue *frameSorter
|
||||||
|
msgBuf []byte
|
||||||
|
|
||||||
|
highestOffset protocol.ByteCount
|
||||||
|
finished bool
|
||||||
|
|
||||||
|
writeOffset protocol.ByteCount
|
||||||
|
writeBuf []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ cryptoStream = &cryptoStreamImpl{}
|
func newCryptoStream() cryptoStream {
|
||||||
|
return &cryptoStreamImpl{
|
||||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
|
queue: newFrameSorter(),
|
||||||
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
}
|
||||||
return &cryptoStreamImpl{str}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetReadOffset sets the read offset.
|
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||||
// It is only needed for the crypto stream.
|
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
|
||||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
|
||||||
func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
|
return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset)
|
||||||
s.receiveStream.readOffset = offset
|
}
|
||||||
s.receiveStream.frameQueue.readPos = offset
|
if s.finished {
|
||||||
|
if highestOffset > s.highestOffset {
|
||||||
|
// reject crypto data received after this stream was already finished
|
||||||
|
return errors.New("received crypto data after change of encryption level")
|
||||||
|
}
|
||||||
|
// ignore data with a smaller offset than the highest received
|
||||||
|
// could e.g. be a retransmission
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset)
|
||||||
|
if err := s.queue.Push(f.Data, f.Offset, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
data, _ := s.queue.Pop()
|
||||||
|
if data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.msgBuf = append(s.msgBuf, data...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCryptoData retrieves data that was received in CRYPTO frames
|
||||||
|
func (s *cryptoStreamImpl) GetCryptoData() []byte {
|
||||||
|
if len(s.msgBuf) < 4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
|
||||||
|
if len(s.msgBuf) < msgLen {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
msg := make([]byte, msgLen)
|
||||||
|
copy(msg, s.msgBuf[:msgLen])
|
||||||
|
s.msgBuf = s.msgBuf[msgLen:]
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *cryptoStreamImpl) Finish() error {
|
||||||
|
if s.queue.HasMoreData() {
|
||||||
|
return errors.New("encryption level changed, but crypto stream has more data to read")
|
||||||
|
}
|
||||||
|
s.finished = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes writes data that should be sent out in CRYPTO frames
|
||||||
|
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
|
||||||
|
s.writeBuf = append(s.writeBuf, p...)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *cryptoStreamImpl) HasData() bool {
|
||||||
|
return len(s.writeBuf) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||||
|
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||||
|
n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
|
||||||
|
f.Data = s.writeBuf[:n]
|
||||||
|
s.writeBuf = s.writeBuf[n:]
|
||||||
|
s.writeOffset += n
|
||||||
|
return f
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cryptoDataHandler interface {
|
||||||
|
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type cryptoStreamManager struct {
|
||||||
|
cryptoHandler cryptoDataHandler
|
||||||
|
|
||||||
|
initialStream cryptoStream
|
||||||
|
handshakeStream cryptoStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCryptoStreamManager(
|
||||||
|
cryptoHandler cryptoDataHandler,
|
||||||
|
initialStream cryptoStream,
|
||||||
|
handshakeStream cryptoStream,
|
||||||
|
) *cryptoStreamManager {
|
||||||
|
return &cryptoStreamManager{
|
||||||
|
cryptoHandler: cryptoHandler,
|
||||||
|
initialStream: initialStream,
|
||||||
|
handshakeStream: handshakeStream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
|
||||||
|
var str cryptoStream
|
||||||
|
switch encLevel {
|
||||||
|
case protocol.EncryptionInitial:
|
||||||
|
str = m.initialStream
|
||||||
|
case protocol.EncryptionHandshake:
|
||||||
|
str = m.handshakeStream
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||||
|
}
|
||||||
|
if err := str.HandleCryptoFrame(frame); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
data := str.GetCryptoData()
|
||||||
|
if data == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
|
||||||
|
return true, str.Finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Binary file not shown.
Before Width: | Height: | Size: 17 KiB |
Binary file not shown.
|
@ -1,65 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
|
||||||
<svg width="601px" height="242px" viewBox="0 0 601 242" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
|
||||||
<!-- Generator: Sketch 3.7.2 (28276) - http://www.bohemiancoding.com/sketch -->
|
|
||||||
<title>quic</title>
|
|
||||||
<desc>Created with Sketch.</desc>
|
|
||||||
<defs></defs>
|
|
||||||
<g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
|
||||||
<g id="Group-4-Copy" transform="translate(586.087964, 15.755663) rotate(-26.000000) translate(-586.087964, -15.755663) translate(573.036290, 1.253803)">
|
|
||||||
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
|
|
||||||
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
|
|
||||||
</g>
|
|
||||||
<g id="Group-3-Copy" transform="translate(499.346306, 15.541651) rotate(25.000000) translate(-499.346306, -15.541651) translate(486.346306, 0.541651)">
|
|
||||||
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
|
|
||||||
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
|
|
||||||
</g>
|
|
||||||
<g id="Group-3" transform="translate(1.896652, 21.000000)">
|
|
||||||
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
|
|
||||||
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
|
|
||||||
</g>
|
|
||||||
<g id="Group-4" transform="translate(117.000000, 22.000000)">
|
|
||||||
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
|
|
||||||
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
|
|
||||||
</g>
|
|
||||||
<g id="Group-2" transform="translate(26.000000, 40.000000)">
|
|
||||||
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
|
|
||||||
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
|
|
||||||
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
|
|
||||||
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
|
|
||||||
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
|
|
||||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
|
|
||||||
</g>
|
|
||||||
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
|
|
||||||
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
|
|
||||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
|
|
||||||
</g>
|
|
||||||
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
|
|
||||||
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
|
|
||||||
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
|
|
||||||
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
<g id="Group-2-Copy" transform="translate(544.820291, 73.354278) scale(-1, 1) translate(-544.820291, -73.354278) translate(498.000000, 40.000000)">
|
|
||||||
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
|
|
||||||
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
|
|
||||||
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
|
|
||||||
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
|
|
||||||
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
|
|
||||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
|
|
||||||
</g>
|
|
||||||
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
|
|
||||||
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
|
|
||||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
|
|
||||||
</g>
|
|
||||||
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
|
|
||||||
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
|
|
||||||
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
|
|
||||||
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
|
|
||||||
</g>
|
|
||||||
</g>
|
|
||||||
<text id="QUIC" font-family="SourceCodePro-Light, Source Code Pro" font-size="288" font-weight="300" letter-spacing="-20" fill="#000000">
|
|
||||||
<tspan x="-13.6" y="197">QUI</tspan>
|
|
||||||
<tspan x="444.8" y="197">C</tspan>
|
|
||||||
</text>
|
|
||||||
</g>
|
|
||||||
</svg>
|
|
Before Width: | Height: | Size: 11 KiB |
|
@ -156,3 +156,8 @@ func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
|
||||||
s.readPos += protocol.ByteCount(len(data))
|
s.readPos += protocol.ByteCount(len(data))
|
||||||
return data, s.readPos >= s.finalOffset
|
return data, s.readPos >= s.finalOffset
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasMoreData says if there is any more data queued at *any* offset.
|
||||||
|
func (s *frameSorter) HasMoreData() bool {
|
||||||
|
return len(s.queue) > 0
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type framer interface {
|
||||||
|
QueueControlFrame(wire.Frame)
|
||||||
|
AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
|
||||||
|
|
||||||
|
AddActiveStream(protocol.StreamID)
|
||||||
|
AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
|
||||||
|
}
|
||||||
|
|
||||||
|
type framerI struct {
|
||||||
|
mutex sync.Mutex
|
||||||
|
|
||||||
|
streamGetter streamGetter
|
||||||
|
version protocol.VersionNumber
|
||||||
|
|
||||||
|
activeStreams map[protocol.StreamID]struct{}
|
||||||
|
streamQueue []protocol.StreamID
|
||||||
|
|
||||||
|
controlFrameMutex sync.Mutex
|
||||||
|
controlFrames []wire.Frame
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ framer = &framerI{}
|
||||||
|
|
||||||
|
func newFramer(
|
||||||
|
streamGetter streamGetter,
|
||||||
|
v protocol.VersionNumber,
|
||||||
|
) framer {
|
||||||
|
return &framerI{
|
||||||
|
streamGetter: streamGetter,
|
||||||
|
activeStreams: make(map[protocol.StreamID]struct{}),
|
||||||
|
version: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *framerI) QueueControlFrame(frame wire.Frame) {
|
||||||
|
f.controlFrameMutex.Lock()
|
||||||
|
f.controlFrames = append(f.controlFrames, frame)
|
||||||
|
f.controlFrameMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *framerI) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) {
|
||||||
|
var length protocol.ByteCount
|
||||||
|
f.controlFrameMutex.Lock()
|
||||||
|
for len(f.controlFrames) > 0 {
|
||||||
|
frame := f.controlFrames[len(f.controlFrames)-1]
|
||||||
|
frameLen := frame.Length(f.version)
|
||||||
|
if length+frameLen > maxLen {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
frames = append(frames, frame)
|
||||||
|
length += frameLen
|
||||||
|
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
|
||||||
|
}
|
||||||
|
f.controlFrameMutex.Unlock()
|
||||||
|
return frames, length
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
||||||
|
f.mutex.Lock()
|
||||||
|
if _, ok := f.activeStreams[id]; !ok {
|
||||||
|
f.streamQueue = append(f.streamQueue, id)
|
||||||
|
f.activeStreams[id] = struct{}{}
|
||||||
|
}
|
||||||
|
f.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame {
|
||||||
|
var length protocol.ByteCount
|
||||||
|
f.mutex.Lock()
|
||||||
|
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
|
||||||
|
numActiveStreams := len(f.streamQueue)
|
||||||
|
for i := 0; i < numActiveStreams; i++ {
|
||||||
|
if maxLen-length < protocol.MinStreamFrameSize {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
id := f.streamQueue[0]
|
||||||
|
f.streamQueue = f.streamQueue[1:]
|
||||||
|
// This should never return an error. Better check it anyway.
|
||||||
|
// The stream will only be in the streamQueue, if it enqueued itself there.
|
||||||
|
str, err := f.streamGetter.GetOrOpenSendStream(id)
|
||||||
|
// The stream can be nil if it completed after it said it had data.
|
||||||
|
if str == nil || err != nil {
|
||||||
|
delete(f.activeStreams, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
frame, hasMoreData := str.popStreamFrame(maxLen - length)
|
||||||
|
if hasMoreData { // put the stream back in the queue (at the end)
|
||||||
|
f.streamQueue = append(f.streamQueue, id)
|
||||||
|
} else { // no more data to send. Stream is not active any more
|
||||||
|
delete(f.activeStreams, id)
|
||||||
|
}
|
||||||
|
if frame == nil { // can happen if the receiveStream was canceled after it said it had data
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
frames = append(frames, frame)
|
||||||
|
length += frame.Length(f.version)
|
||||||
|
}
|
||||||
|
f.mutex.Unlock()
|
||||||
|
return frames
|
||||||
|
}
|
|
@ -1,314 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
"golang.org/x/net/idna"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
type roundTripperOpts struct {
|
|
||||||
DisableCompression bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var dialAddr = quic.DialAddr
|
|
||||||
|
|
||||||
// client is a HTTP2 client doing QUIC requests
|
|
||||||
type client struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
tlsConf *tls.Config
|
|
||||||
config *quic.Config
|
|
||||||
opts *roundTripperOpts
|
|
||||||
|
|
||||||
hostname string
|
|
||||||
handshakeErr error
|
|
||||||
dialOnce sync.Once
|
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
|
||||||
|
|
||||||
session quic.Session
|
|
||||||
headerStream quic.Stream
|
|
||||||
headerErr *qerr.QuicError
|
|
||||||
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
|
||||||
requestWriter *requestWriter
|
|
||||||
|
|
||||||
responses map[protocol.StreamID]chan *http.Response
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ http.RoundTripper = &client{}
|
|
||||||
|
|
||||||
var defaultQuicConfig = &quic.Config{
|
|
||||||
RequestConnectionIDOmission: true,
|
|
||||||
KeepAlive: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// newClient creates a new client
|
|
||||||
func newClient(
|
|
||||||
hostname string,
|
|
||||||
tlsConfig *tls.Config,
|
|
||||||
opts *roundTripperOpts,
|
|
||||||
quicConfig *quic.Config,
|
|
||||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
|
||||||
) *client {
|
|
||||||
config := defaultQuicConfig
|
|
||||||
if quicConfig != nil {
|
|
||||||
config = quicConfig
|
|
||||||
}
|
|
||||||
return &client{
|
|
||||||
hostname: authorityAddr("https", hostname),
|
|
||||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
|
||||||
tlsConf: tlsConfig,
|
|
||||||
config: config,
|
|
||||||
opts: opts,
|
|
||||||
headerErrored: make(chan struct{}),
|
|
||||||
dialer: dialer,
|
|
||||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dial dials the connection
|
|
||||||
func (c *client) dial() error {
|
|
||||||
var err error
|
|
||||||
if c.dialer != nil {
|
|
||||||
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
|
||||||
} else {
|
|
||||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// once the version has been negotiated, open the header stream
|
|
||||||
c.headerStream, err = c.session.OpenStream()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
|
|
||||||
go c.handleHeaderStream()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleHeaderStream() {
|
|
||||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
|
||||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
for err == nil {
|
|
||||||
err = c.readResponse(h2framer, decoder)
|
|
||||||
}
|
|
||||||
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
|
|
||||||
c.logger.Debugf("Error handling header stream: %s", err)
|
|
||||||
}
|
|
||||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
|
||||||
// stop all running request
|
|
||||||
close(c.headerErrored)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
|
||||||
frame, err := h2framer.ReadFrame()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hframe, ok := frame.(*http2.HeadersFrame)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("not a headers frame")
|
|
||||||
}
|
|
||||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
|
||||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mutex.RLock()
|
|
||||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
|
||||||
c.mutex.RUnlock()
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
rsp, err := responseFromHeaders(mhframe)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
responseChan <- rsp
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Roundtrip executes a request and returns a response
|
|
||||||
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
// TODO: add port to address, if it doesn't have one
|
|
||||||
if req.URL.Scheme != "https" {
|
|
||||||
return nil, errors.New("quic http2: unsupported scheme")
|
|
||||||
}
|
|
||||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
|
||||||
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.dialOnce.Do(func() {
|
|
||||||
c.handshakeErr = c.dial()
|
|
||||||
})
|
|
||||||
|
|
||||||
if c.handshakeErr != nil {
|
|
||||||
return nil, c.handshakeErr
|
|
||||||
}
|
|
||||||
|
|
||||||
hasBody := (req.Body != nil)
|
|
||||||
|
|
||||||
responseChan := make(chan *http.Response)
|
|
||||||
dataStream, err := c.session.OpenStreamSync()
|
|
||||||
if err != nil {
|
|
||||||
_ = c.closeWithError(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.mutex.Lock()
|
|
||||||
c.responses[dataStream.StreamID()] = responseChan
|
|
||||||
c.mutex.Unlock()
|
|
||||||
|
|
||||||
var requestedGzip bool
|
|
||||||
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
|
||||||
requestedGzip = true
|
|
||||||
}
|
|
||||||
// TODO: add support for trailers
|
|
||||||
endStream := !hasBody
|
|
||||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
|
||||||
if err != nil {
|
|
||||||
_ = c.closeWithError(err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resc := make(chan error, 1)
|
|
||||||
if hasBody {
|
|
||||||
go func() {
|
|
||||||
resc <- c.writeRequestBody(dataStream, req.Body)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
var res *http.Response
|
|
||||||
|
|
||||||
var receivedResponse bool
|
|
||||||
var bodySent bool
|
|
||||||
|
|
||||||
if !hasBody {
|
|
||||||
bodySent = true
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := req.Context()
|
|
||||||
for !(bodySent && receivedResponse) {
|
|
||||||
select {
|
|
||||||
case res = <-responseChan:
|
|
||||||
receivedResponse = true
|
|
||||||
c.mutex.Lock()
|
|
||||||
delete(c.responses, dataStream.StreamID())
|
|
||||||
c.mutex.Unlock()
|
|
||||||
case err := <-resc:
|
|
||||||
bodySent = true
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
// error code 6 signals that stream was canceled
|
|
||||||
dataStream.CancelRead(6)
|
|
||||||
dataStream.CancelWrite(6)
|
|
||||||
c.mutex.Lock()
|
|
||||||
delete(c.responses, dataStream.StreamID())
|
|
||||||
c.mutex.Unlock()
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-c.headerErrored:
|
|
||||||
// an error occurred on the header stream
|
|
||||||
_ = c.closeWithError(c.headerErr)
|
|
||||||
return nil, c.headerErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: correctly set this variable
|
|
||||||
var streamEnded bool
|
|
||||||
isHead := (req.Method == "HEAD")
|
|
||||||
|
|
||||||
res = setLength(res, isHead, streamEnded)
|
|
||||||
|
|
||||||
if streamEnded || isHead {
|
|
||||||
res.Body = noBody
|
|
||||||
} else {
|
|
||||||
res.Body = dataStream
|
|
||||||
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
|
||||||
res.Header.Del("Content-Encoding")
|
|
||||||
res.Header.Del("Content-Length")
|
|
||||||
res.ContentLength = -1
|
|
||||||
res.Body = &gzipReader{body: res.Body}
|
|
||||||
res.Uncompressed = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Request = req
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
|
|
||||||
defer func() {
|
|
||||||
cerr := body.Close()
|
|
||||||
if err == nil {
|
|
||||||
// TODO: what to do with dataStream here? Maybe reset it?
|
|
||||||
err = cerr
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err = io.Copy(dataStream, body)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: what to do with dataStream here? Maybe reset it?
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return dataStream.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) closeWithError(e error) error {
|
|
||||||
if c.session == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the client
|
|
||||||
func (c *client) Close() error {
|
|
||||||
if c.session == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.session.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// copied from net/transport.go
|
|
||||||
|
|
||||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
|
||||||
// and returns a host:port. The port 443 is added if needed.
|
|
||||||
func authorityAddr(scheme string, authority string) (addr string) {
|
|
||||||
host, port, err := net.SplitHostPort(authority)
|
|
||||||
if err != nil { // authority didn't have a port
|
|
||||||
port = "443"
|
|
||||||
if scheme == "http" {
|
|
||||||
port = "80"
|
|
||||||
}
|
|
||||||
host = authority
|
|
||||||
}
|
|
||||||
if a, err := idna.ToASCII(host); err == nil {
|
|
||||||
host = a
|
|
||||||
}
|
|
||||||
// IPv6 address literal, without a port:
|
|
||||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
|
||||||
return host + ":" + port
|
|
||||||
}
|
|
||||||
return net.JoinHostPort(host, port)
|
|
||||||
}
|
|
|
@ -1,35 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
// copied from net/transport.go
|
|
||||||
|
|
||||||
// gzipReader wraps a response body so it can lazily
|
|
||||||
// call gzip.NewReader on the first call to Read
|
|
||||||
import (
|
|
||||||
"compress/gzip"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// call gzip.NewReader on the first call to Read
|
|
||||||
type gzipReader struct {
|
|
||||||
body io.ReadCloser // underlying Response.Body
|
|
||||||
zr *gzip.Reader // lazily-initialized gzip reader
|
|
||||||
zerr error // sticky error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (gz *gzipReader) Read(p []byte) (n int, err error) {
|
|
||||||
if gz.zerr != nil {
|
|
||||||
return 0, gz.zerr
|
|
||||||
}
|
|
||||||
if gz.zr == nil {
|
|
||||||
gz.zr, err = gzip.NewReader(gz.body)
|
|
||||||
if err != nil {
|
|
||||||
gz.zerr = err
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return gz.zr.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (gz *gzipReader) Close() error {
|
|
||||||
return gz.body.Close()
|
|
||||||
}
|
|
|
@ -1,77 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
)
|
|
||||||
|
|
||||||
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
|
|
||||||
var path, authority, method, contentLengthStr string
|
|
||||||
httpHeaders := http.Header{}
|
|
||||||
|
|
||||||
for _, h := range headers {
|
|
||||||
switch h.Name {
|
|
||||||
case ":path":
|
|
||||||
path = h.Value
|
|
||||||
case ":method":
|
|
||||||
method = h.Value
|
|
||||||
case ":authority":
|
|
||||||
authority = h.Value
|
|
||||||
case "content-length":
|
|
||||||
contentLengthStr = h.Value
|
|
||||||
default:
|
|
||||||
if !h.IsPseudo() {
|
|
||||||
httpHeaders.Add(h.Name, h.Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
|
|
||||||
if len(httpHeaders["Cookie"]) > 0 {
|
|
||||||
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
|
|
||||||
return nil, errors.New(":path, :authority and :method must not be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
u, err := url.Parse(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var contentLength int64
|
|
||||||
if len(contentLengthStr) > 0 {
|
|
||||||
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &http.Request{
|
|
||||||
Method: method,
|
|
||||||
URL: u,
|
|
||||||
Proto: "HTTP/2.0",
|
|
||||||
ProtoMajor: 2,
|
|
||||||
ProtoMinor: 0,
|
|
||||||
Header: httpHeaders,
|
|
||||||
Body: nil,
|
|
||||||
ContentLength: contentLength,
|
|
||||||
Host: authority,
|
|
||||||
RequestURI: path,
|
|
||||||
TLS: &tls.ConnectionState{},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func hostnameFromRequest(req *http.Request) string {
|
|
||||||
if req.URL != nil {
|
|
||||||
return req.URL.Host
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
|
@ -1,29 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
type requestBody struct {
|
|
||||||
requestRead bool
|
|
||||||
dataStream quic.Stream
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure the requestBody can be used as a http.Request.Body
|
|
||||||
var _ io.ReadCloser = &requestBody{}
|
|
||||||
|
|
||||||
func newRequestBody(stream quic.Stream) *requestBody {
|
|
||||||
return &requestBody{dataStream: stream}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *requestBody) Read(p []byte) (int, error) {
|
|
||||||
b.requestRead = true
|
|
||||||
return b.dataStream.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *requestBody) Close() error {
|
|
||||||
// stream's Close() closes the write side, not the read side
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,203 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/net/http/httpguts"
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type requestWriter struct {
|
|
||||||
mutex sync.Mutex
|
|
||||||
headerStream quic.Stream
|
|
||||||
|
|
||||||
henc *hpack.Encoder
|
|
||||||
hbuf bytes.Buffer // HPACK encoder writes into this
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
const defaultUserAgent = "quic-go"
|
|
||||||
|
|
||||||
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
|
|
||||||
rw := &requestWriter{
|
|
||||||
headerStream: headerStream,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
rw.henc = hpack.NewEncoder(&rw.hbuf)
|
|
||||||
return rw
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
|
|
||||||
// TODO: add support for trailers
|
|
||||||
// TODO: add support for gzip compression
|
|
||||||
// TODO: write continuation frames, if the header frame is too long
|
|
||||||
|
|
||||||
w.mutex.Lock()
|
|
||||||
defer w.mutex.Unlock()
|
|
||||||
|
|
||||||
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
|
|
||||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
|
||||||
return h2framer.WriteHeaders(http2.HeadersFrameParam{
|
|
||||||
StreamID: uint32(dataStreamID),
|
|
||||||
EndHeaders: true,
|
|
||||||
EndStream: endStream,
|
|
||||||
BlockFragment: w.hbuf.Bytes(),
|
|
||||||
Priority: http2.PriorityParam{Weight: 0xff},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// the rest of this files is copied from http2.Transport
|
|
||||||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
|
|
||||||
w.hbuf.Reset()
|
|
||||||
|
|
||||||
host := req.Host
|
|
||||||
if host == "" {
|
|
||||||
host = req.URL.Host
|
|
||||||
}
|
|
||||||
host, err := httpguts.PunycodeHostPort(host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var path string
|
|
||||||
if req.Method != "CONNECT" {
|
|
||||||
path = req.URL.RequestURI()
|
|
||||||
if !validPseudoPath(path) {
|
|
||||||
orig := path
|
|
||||||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
|
||||||
if !validPseudoPath(path) {
|
|
||||||
if req.URL.Opaque != "" {
|
|
||||||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("invalid request :path %q", orig)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for any invalid headers and return an error before we
|
|
||||||
// potentially pollute our hpack state. (We want to be able to
|
|
||||||
// continue to reuse the hpack encoder for future requests)
|
|
||||||
for k, vv := range req.Header {
|
|
||||||
if !httpguts.ValidHeaderFieldName(k) {
|
|
||||||
return nil, fmt.Errorf("invalid HTTP header name %q", k)
|
|
||||||
}
|
|
||||||
for _, v := range vv {
|
|
||||||
if !httpguts.ValidHeaderFieldValue(v) {
|
|
||||||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 8.1.2.3 Request Pseudo-Header Fields
|
|
||||||
// The :path pseudo-header field includes the path and query parts of the
|
|
||||||
// target URI (the path-absolute production and optionally a '?' character
|
|
||||||
// followed by the query production (see Sections 3.3 and 3.4 of
|
|
||||||
// [RFC3986]).
|
|
||||||
w.writeHeader(":authority", host)
|
|
||||||
w.writeHeader(":method", req.Method)
|
|
||||||
if req.Method != "CONNECT" {
|
|
||||||
w.writeHeader(":path", path)
|
|
||||||
w.writeHeader(":scheme", req.URL.Scheme)
|
|
||||||
}
|
|
||||||
if trailers != "" {
|
|
||||||
w.writeHeader("trailer", trailers)
|
|
||||||
}
|
|
||||||
|
|
||||||
var didUA bool
|
|
||||||
for k, vv := range req.Header {
|
|
||||||
lowKey := strings.ToLower(k)
|
|
||||||
switch lowKey {
|
|
||||||
case "host", "content-length":
|
|
||||||
// Host is :authority, already sent.
|
|
||||||
// Content-Length is automatic, set below.
|
|
||||||
continue
|
|
||||||
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
|
|
||||||
// Per 8.1.2.2 Connection-Specific Header
|
|
||||||
// Fields, don't send connection-specific
|
|
||||||
// fields. We have already checked if any
|
|
||||||
// are error-worthy so just ignore the rest.
|
|
||||||
continue
|
|
||||||
case "user-agent":
|
|
||||||
// Match Go's http1 behavior: at most one
|
|
||||||
// User-Agent. If set to nil or empty string,
|
|
||||||
// then omit it. Otherwise if not mentioned,
|
|
||||||
// include the default (below).
|
|
||||||
didUA = true
|
|
||||||
if len(vv) < 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
vv = vv[:1]
|
|
||||||
if vv[0] == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, v := range vv {
|
|
||||||
w.writeHeader(lowKey, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if shouldSendReqContentLength(req.Method, contentLength) {
|
|
||||||
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
|
|
||||||
}
|
|
||||||
if addGzipHeader {
|
|
||||||
w.writeHeader("accept-encoding", "gzip")
|
|
||||||
}
|
|
||||||
if !didUA {
|
|
||||||
w.writeHeader("user-agent", defaultUserAgent)
|
|
||||||
}
|
|
||||||
return w.hbuf.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *requestWriter) writeHeader(name, value string) {
|
|
||||||
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
|
|
||||||
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
|
||||||
// a "content-length" request header. This logic is basically a copy of the net/http
|
|
||||||
// transferWriter.shouldSendContentLength.
|
|
||||||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
|
||||||
// -1 means unknown.
|
|
||||||
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
|
||||||
if contentLength > 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if contentLength < 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// For zero bodies, whether we send a content-length depends on the method.
|
|
||||||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
|
||||||
switch method {
|
|
||||||
case "POST", "PUT", "PATCH":
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validPseudoPath(v string) bool {
|
|
||||||
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
|
|
||||||
}
|
|
||||||
|
|
||||||
// actualContentLength returns a sanitized version of
|
|
||||||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
|
||||||
// means unknown.
|
|
||||||
func actualContentLength(req *http.Request) int64 {
|
|
||||||
if req.Body == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
if req.ContentLength != 0 {
|
|
||||||
return req.ContentLength
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
|
@ -1,95 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"net/textproto"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// copied from net/http2/transport.go
|
|
||||||
|
|
||||||
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
|
|
||||||
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
|
|
||||||
|
|
||||||
// from the handleResponse function
|
|
||||||
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
|
|
||||||
if f.Truncated {
|
|
||||||
return nil, errResponseHeaderListSize
|
|
||||||
}
|
|
||||||
|
|
||||||
status := f.PseudoValue("status")
|
|
||||||
if status == "" {
|
|
||||||
return nil, errors.New("missing status pseudo header")
|
|
||||||
}
|
|
||||||
statusCode, err := strconv.Atoi(status)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("malformed non-numeric status pseudo header")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: handle statusCode == 100
|
|
||||||
|
|
||||||
header := make(http.Header)
|
|
||||||
res := &http.Response{
|
|
||||||
Proto: "HTTP/2.0",
|
|
||||||
ProtoMajor: 2,
|
|
||||||
Header: header,
|
|
||||||
StatusCode: statusCode,
|
|
||||||
Status: status + " " + http.StatusText(statusCode),
|
|
||||||
}
|
|
||||||
for _, hf := range f.RegularFields() {
|
|
||||||
key := http.CanonicalHeaderKey(hf.Name)
|
|
||||||
if key == "Trailer" {
|
|
||||||
t := res.Trailer
|
|
||||||
if t == nil {
|
|
||||||
t = make(http.Header)
|
|
||||||
res.Trailer = t
|
|
||||||
}
|
|
||||||
foreachHeaderElement(hf.Value, func(v string) {
|
|
||||||
t[http.CanonicalHeaderKey(v)] = nil
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
header[key] = append(header[key], hf.Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// continuation of the handleResponse function
|
|
||||||
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
|
|
||||||
if !streamEnded || isHead {
|
|
||||||
res.ContentLength = -1
|
|
||||||
if clens := res.Header["Content-Length"]; len(clens) == 1 {
|
|
||||||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
|
|
||||||
res.ContentLength = clen64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// copied from net/http/server.go
|
|
||||||
|
|
||||||
// foreachHeaderElement splits v according to the "#rule" construction
|
|
||||||
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
|
|
||||||
func foreachHeaderElement(v string, fn func(string)) {
|
|
||||||
v = textproto.TrimString(v)
|
|
||||||
if v == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !strings.Contains(v, ",") {
|
|
||||||
fn(v)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, f := range strings.Split(v, ",") {
|
|
||||||
if f = textproto.TrimString(f); f != "" {
|
|
||||||
fn(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,114 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
)
|
|
||||||
|
|
||||||
type responseWriter struct {
|
|
||||||
dataStreamID protocol.StreamID
|
|
||||||
dataStream quic.Stream
|
|
||||||
|
|
||||||
headerStream quic.Stream
|
|
||||||
headerStreamMutex *sync.Mutex
|
|
||||||
|
|
||||||
header http.Header
|
|
||||||
status int // status code passed to WriteHeader
|
|
||||||
headerWritten bool
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func newResponseWriter(
|
|
||||||
headerStream quic.Stream,
|
|
||||||
headerStreamMutex *sync.Mutex,
|
|
||||||
dataStream quic.Stream,
|
|
||||||
dataStreamID protocol.StreamID,
|
|
||||||
logger utils.Logger,
|
|
||||||
) *responseWriter {
|
|
||||||
return &responseWriter{
|
|
||||||
header: http.Header{},
|
|
||||||
headerStream: headerStream,
|
|
||||||
headerStreamMutex: headerStreamMutex,
|
|
||||||
dataStream: dataStream,
|
|
||||||
dataStreamID: dataStreamID,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseWriter) Header() http.Header {
|
|
||||||
return w.header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseWriter) WriteHeader(status int) {
|
|
||||||
if w.headerWritten {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.headerWritten = true
|
|
||||||
w.status = status
|
|
||||||
|
|
||||||
var headers bytes.Buffer
|
|
||||||
enc := hpack.NewEncoder(&headers)
|
|
||||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
|
||||||
|
|
||||||
for k, v := range w.header {
|
|
||||||
for index := range v {
|
|
||||||
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
w.logger.Infof("Responding with %d", status)
|
|
||||||
w.headerStreamMutex.Lock()
|
|
||||||
defer w.headerStreamMutex.Unlock()
|
|
||||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
|
||||||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
|
|
||||||
StreamID: uint32(w.dataStreamID),
|
|
||||||
EndHeaders: true,
|
|
||||||
BlockFragment: headers.Bytes(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
w.logger.Errorf("could not write h2 header: %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
|
||||||
if !w.headerWritten {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
}
|
|
||||||
if !bodyAllowedForStatus(w.status) {
|
|
||||||
return 0, http.ErrBodyNotAllowed
|
|
||||||
}
|
|
||||||
return w.dataStream.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseWriter) Flush() {}
|
|
||||||
|
|
||||||
// This is a NOP. Use http.Request.Context
|
|
||||||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
|
||||||
|
|
||||||
// test that we implement http.Flusher
|
|
||||||
var _ http.Flusher = &responseWriter{}
|
|
||||||
|
|
||||||
// copied from http2/http2.go
|
|
||||||
// bodyAllowedForStatus reports whether a given response status code
|
|
||||||
// permits a body. See RFC 2616, section 4.4.
|
|
||||||
func bodyAllowedForStatus(status int) bool {
|
|
||||||
switch {
|
|
||||||
case status >= 100 && status <= 199:
|
|
||||||
return false
|
|
||||||
case status == 204:
|
|
||||||
return false
|
|
||||||
case status == 304:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
9
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go
generated
vendored
9
vendor/github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go
generated
vendored
|
@ -1,9 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11.
|
|
||||||
// By defining it in a separate file, we can exclude this file from staticcheck.
|
|
||||||
|
|
||||||
// test that we implement http.CloseNotifier
|
|
||||||
var _ http.CloseNotifier = &responseWriter{}
|
|
|
@ -1,179 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
|
|
||||||
"golang.org/x/net/http/httpguts"
|
|
||||||
)
|
|
||||||
|
|
||||||
type roundTripCloser interface {
|
|
||||||
http.RoundTripper
|
|
||||||
io.Closer
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoundTripper implements the http.RoundTripper interface
|
|
||||||
type RoundTripper struct {
|
|
||||||
mutex sync.Mutex
|
|
||||||
|
|
||||||
// DisableCompression, if true, prevents the Transport from
|
|
||||||
// requesting compression with an "Accept-Encoding: gzip"
|
|
||||||
// request header when the Request contains no existing
|
|
||||||
// Accept-Encoding value. If the Transport requests gzip on
|
|
||||||
// its own and gets a gzipped response, it's transparently
|
|
||||||
// decoded in the Response.Body. However, if the user
|
|
||||||
// explicitly requested gzip it is not automatically
|
|
||||||
// uncompressed.
|
|
||||||
DisableCompression bool
|
|
||||||
|
|
||||||
// TLSClientConfig specifies the TLS configuration to use with
|
|
||||||
// tls.Client. If nil, the default configuration is used.
|
|
||||||
TLSClientConfig *tls.Config
|
|
||||||
|
|
||||||
// QuicConfig is the quic.Config used for dialing new connections.
|
|
||||||
// If nil, reasonable default values will be used.
|
|
||||||
QuicConfig *quic.Config
|
|
||||||
|
|
||||||
// Dial specifies an optional dial function for creating QUIC
|
|
||||||
// connections for requests.
|
|
||||||
// If Dial is nil, quic.DialAddr will be used.
|
|
||||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
|
||||||
|
|
||||||
clients map[string]roundTripCloser
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
|
||||||
type RoundTripOpt struct {
|
|
||||||
// OnlyCachedConn controls whether the RoundTripper may
|
|
||||||
// create a new QUIC connection. If set true and
|
|
||||||
// no cached connection is available, RoundTrip
|
|
||||||
// will return ErrNoCachedConn.
|
|
||||||
OnlyCachedConn bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ roundTripCloser = &RoundTripper{}
|
|
||||||
|
|
||||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
|
||||||
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
|
|
||||||
|
|
||||||
// RoundTripOpt is like RoundTrip, but takes options.
|
|
||||||
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
|
||||||
if req.URL == nil {
|
|
||||||
closeRequestBody(req)
|
|
||||||
return nil, errors.New("quic: nil Request.URL")
|
|
||||||
}
|
|
||||||
if req.URL.Host == "" {
|
|
||||||
closeRequestBody(req)
|
|
||||||
return nil, errors.New("quic: no Host in request URL")
|
|
||||||
}
|
|
||||||
if req.Header == nil {
|
|
||||||
closeRequestBody(req)
|
|
||||||
return nil, errors.New("quic: nil Request.Header")
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.URL.Scheme == "https" {
|
|
||||||
for k, vv := range req.Header {
|
|
||||||
if !httpguts.ValidHeaderFieldName(k) {
|
|
||||||
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
|
|
||||||
}
|
|
||||||
for _, v := range vv {
|
|
||||||
if !httpguts.ValidHeaderFieldValue(v) {
|
|
||||||
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
closeRequestBody(req)
|
|
||||||
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Method != "" && !validMethod(req.Method) {
|
|
||||||
closeRequestBody(req)
|
|
||||||
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname := authorityAddr("https", hostnameFromRequest(req))
|
|
||||||
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return cl.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoundTrip does a round trip.
|
|
||||||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
return r.RoundTripOpt(req, RoundTripOpt{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
|
|
||||||
r.mutex.Lock()
|
|
||||||
defer r.mutex.Unlock()
|
|
||||||
|
|
||||||
if r.clients == nil {
|
|
||||||
r.clients = make(map[string]roundTripCloser)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, ok := r.clients[hostname]
|
|
||||||
if !ok {
|
|
||||||
if onlyCached {
|
|
||||||
return nil, ErrNoCachedConn
|
|
||||||
}
|
|
||||||
client = newClient(
|
|
||||||
hostname,
|
|
||||||
r.TLSClientConfig,
|
|
||||||
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
|
||||||
r.QuicConfig,
|
|
||||||
r.Dial,
|
|
||||||
)
|
|
||||||
r.clients[hostname] = client
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the QUIC connections that this RoundTripper has used
|
|
||||||
func (r *RoundTripper) Close() error {
|
|
||||||
r.mutex.Lock()
|
|
||||||
defer r.mutex.Unlock()
|
|
||||||
for _, client := range r.clients {
|
|
||||||
if err := client.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.clients = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeRequestBody(req *http.Request) {
|
|
||||||
if req.Body != nil {
|
|
||||||
req.Body.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validMethod(method string) bool {
|
|
||||||
/*
|
|
||||||
Method = "OPTIONS" ; Section 9.2
|
|
||||||
| "GET" ; Section 9.3
|
|
||||||
| "HEAD" ; Section 9.4
|
|
||||||
| "POST" ; Section 9.5
|
|
||||||
| "PUT" ; Section 9.6
|
|
||||||
| "DELETE" ; Section 9.7
|
|
||||||
| "TRACE" ; Section 9.8
|
|
||||||
| "CONNECT" ; Section 9.9
|
|
||||||
| extension-method
|
|
||||||
extension-method = token
|
|
||||||
token = 1*<any CHAR except CTLs or separators>
|
|
||||||
*/
|
|
||||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// copied from net/http/http.go
|
|
||||||
func isNotToken(r rune) bool {
|
|
||||||
return !httpguts.IsTokenRune(r)
|
|
||||||
}
|
|
|
@ -1,402 +0,0 @@
|
||||||
package h2quic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
quic "github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"golang.org/x/net/http2/hpack"
|
|
||||||
)
|
|
||||||
|
|
||||||
type streamCreator interface {
|
|
||||||
quic.Session
|
|
||||||
GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type remoteCloser interface {
|
|
||||||
CloseRemote(protocol.ByteCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
// allows mocking of quic.Listen and quic.ListenAddr
|
|
||||||
var (
|
|
||||||
quicListen = quic.Listen
|
|
||||||
quicListenAddr = quic.ListenAddr
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server is a HTTP2 server listening for QUIC connections.
|
|
||||||
type Server struct {
|
|
||||||
*http.Server
|
|
||||||
|
|
||||||
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
|
|
||||||
// If nil, it uses reasonable default values.
|
|
||||||
QuicConfig *quic.Config
|
|
||||||
|
|
||||||
// Private flag for demo, do not use
|
|
||||||
CloseAfterFirstRequest bool
|
|
||||||
|
|
||||||
port uint32 // used atomically
|
|
||||||
|
|
||||||
listenerMutex sync.Mutex
|
|
||||||
listener quic.Listener
|
|
||||||
closed bool
|
|
||||||
|
|
||||||
supportedVersionsAsString string
|
|
||||||
|
|
||||||
logger utils.Logger // will be set by Server.serveImpl()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
|
||||||
func (s *Server) ListenAndServe() error {
|
|
||||||
if s.Server == nil {
|
|
||||||
return errors.New("use of h2quic.Server without http.Server")
|
|
||||||
}
|
|
||||||
return s.serveImpl(s.TLSConfig, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
|
||||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
|
||||||
var err error
|
|
||||||
certs := make([]tls.Certificate, 1)
|
|
||||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// We currently only use the cert-related stuff from tls.Config,
|
|
||||||
// so we don't need to make a full copy.
|
|
||||||
config := &tls.Config{
|
|
||||||
Certificates: certs,
|
|
||||||
}
|
|
||||||
return s.serveImpl(config, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serve an existing UDP connection.
|
|
||||||
func (s *Server) Serve(conn net.PacketConn) error {
|
|
||||||
return s.serveImpl(s.TLSConfig, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
|
||||||
if s.Server == nil {
|
|
||||||
return errors.New("use of h2quic.Server without http.Server")
|
|
||||||
}
|
|
||||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
|
||||||
s.listenerMutex.Lock()
|
|
||||||
if s.closed {
|
|
||||||
s.listenerMutex.Unlock()
|
|
||||||
return errors.New("Server is already closed")
|
|
||||||
}
|
|
||||||
if s.listener != nil {
|
|
||||||
s.listenerMutex.Unlock()
|
|
||||||
return errors.New("ListenAndServe may only be called once")
|
|
||||||
}
|
|
||||||
|
|
||||||
var ln quic.Listener
|
|
||||||
var err error
|
|
||||||
if conn == nil {
|
|
||||||
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
|
|
||||||
} else {
|
|
||||||
ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
s.listenerMutex.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.listener = ln
|
|
||||||
s.listenerMutex.Unlock()
|
|
||||||
|
|
||||||
for {
|
|
||||||
sess, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
go s.handleHeaderStream(sess.(streamCreator))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleHeaderStream(session streamCreator) {
|
|
||||||
stream, err := session.AcceptStream()
|
|
||||||
if err != nil {
|
|
||||||
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
|
||||||
h2framer := http2.NewFramer(nil, stream)
|
|
||||||
|
|
||||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
|
||||||
for {
|
|
||||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
|
||||||
// QuicErrors must originate from stream.Read() returning an error.
|
|
||||||
// In this case, the session has already logged the error, so we don't
|
|
||||||
// need to log it again.
|
|
||||||
errorCode := qerr.InternalError
|
|
||||||
if qerr, ok := err.(*qerr.QuicError); !ok {
|
|
||||||
errorCode = qerr.ErrorCode
|
|
||||||
s.logger.Errorf("error handling h2 request: %s", err.Error())
|
|
||||||
}
|
|
||||||
session.CloseWithError(quic.ErrorCode(errorCode), err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
|
||||||
h2frame, err := h2framer.ReadFrame()
|
|
||||||
if err != nil {
|
|
||||||
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
|
||||||
}
|
|
||||||
var h2headersFrame *http2.HeadersFrame
|
|
||||||
switch f := h2frame.(type) {
|
|
||||||
case *http2.PriorityFrame:
|
|
||||||
// ignore PRIORITY frames
|
|
||||||
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
|
|
||||||
return nil
|
|
||||||
case *http2.HeadersFrame:
|
|
||||||
h2headersFrame = f
|
|
||||||
default:
|
|
||||||
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !h2headersFrame.HeadersEnded() {
|
|
||||||
return errors.New("http2 header continuation not implemented")
|
|
||||||
}
|
|
||||||
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := requestFromHeaders(headers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.logger.Debug() {
|
|
||||||
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
|
||||||
} else {
|
|
||||||
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
|
||||||
}
|
|
||||||
|
|
||||||
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
|
|
||||||
if dataStream == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleRequest should be as non-blocking as possible to minimize
|
|
||||||
// head-of-line blocking. Potentially blocking code is run in a separate
|
|
||||||
// goroutine, enabling handleRequest to return before the code is executed.
|
|
||||||
go func() {
|
|
||||||
streamEnded := h2headersFrame.StreamEnded()
|
|
||||||
if streamEnded {
|
|
||||||
dataStream.(remoteCloser).CloseRemote(0)
|
|
||||||
streamEnded = true
|
|
||||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
|
||||||
}
|
|
||||||
|
|
||||||
req = req.WithContext(dataStream.Context())
|
|
||||||
reqBody := newRequestBody(dataStream)
|
|
||||||
req.Body = reqBody
|
|
||||||
|
|
||||||
req.RemoteAddr = session.RemoteAddr().String()
|
|
||||||
|
|
||||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
|
|
||||||
|
|
||||||
handler := s.Handler
|
|
||||||
if handler == nil {
|
|
||||||
handler = http.DefaultServeMux
|
|
||||||
}
|
|
||||||
panicked := false
|
|
||||||
func() {
|
|
||||||
defer func() {
|
|
||||||
if p := recover(); p != nil {
|
|
||||||
// Copied from net/http/server.go
|
|
||||||
const size = 64 << 10
|
|
||||||
buf := make([]byte, size)
|
|
||||||
buf = buf[:runtime.Stack(buf, false)]
|
|
||||||
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
|
|
||||||
panicked = true
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
handler.ServeHTTP(responseWriter, req)
|
|
||||||
}()
|
|
||||||
if panicked {
|
|
||||||
responseWriter.WriteHeader(500)
|
|
||||||
} else {
|
|
||||||
responseWriter.WriteHeader(200)
|
|
||||||
}
|
|
||||||
if responseWriter.dataStream != nil {
|
|
||||||
if !streamEnded && !reqBody.requestRead {
|
|
||||||
// in gQUIC, the error code doesn't matter, so just use 0 here
|
|
||||||
responseWriter.dataStream.CancelRead(0)
|
|
||||||
}
|
|
||||||
responseWriter.dataStream.Close()
|
|
||||||
}
|
|
||||||
if s.CloseAfterFirstRequest {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
session.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
|
|
||||||
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
|
||||||
func (s *Server) Close() error {
|
|
||||||
s.listenerMutex.Lock()
|
|
||||||
defer s.listenerMutex.Unlock()
|
|
||||||
s.closed = true
|
|
||||||
if s.listener != nil {
|
|
||||||
err := s.listener.Close()
|
|
||||||
s.listener = nil
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
|
|
||||||
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
|
||||||
func (s *Server) CloseGracefully(timeout time.Duration) error {
|
|
||||||
// TODO: implement
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
|
|
||||||
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
|
|
||||||
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
|
|
||||||
func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
|
||||||
port := atomic.LoadUint32(&s.port)
|
|
||||||
|
|
||||||
if port == 0 {
|
|
||||||
// Extract port from s.Server.Addr
|
|
||||||
_, portStr, err := net.SplitHostPort(s.Server.Addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
portInt, err := net.LookupPort("tcp", portStr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
port = uint32(portInt)
|
|
||||||
atomic.StoreUint32(&s.port, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.supportedVersionsAsString == "" {
|
|
||||||
var versions []string
|
|
||||||
for _, v := range protocol.SupportedVersions {
|
|
||||||
versions = append(versions, v.ToAltSvc())
|
|
||||||
}
|
|
||||||
s.supportedVersionsAsString = strings.Join(versions, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
|
||||||
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
|
|
||||||
// used when handler is nil.
|
|
||||||
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
|
|
||||||
server := &Server{
|
|
||||||
Server: &http.Server{
|
|
||||||
Addr: addr,
|
|
||||||
Handler: handler,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return server.ListenAndServeTLS(certFile, keyFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenAndServe listens on the given network address for both, TLS and QUIC
|
|
||||||
// connetions in parallel. It returns if one of the two returns an error.
|
|
||||||
// http.DefaultServeMux is used when handler is nil.
|
|
||||||
// The correct Alt-Svc headers for QUIC are set.
|
|
||||||
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
|
|
||||||
// Load certs
|
|
||||||
var err error
|
|
||||||
certs := make([]tls.Certificate, 1)
|
|
||||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// We currently only use the cert-related stuff from tls.Config,
|
|
||||||
// so we don't need to make a full copy.
|
|
||||||
config := &tls.Config{
|
|
||||||
Certificates: certs,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open the listeners
|
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer udpConn.Close()
|
|
||||||
|
|
||||||
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
tcpConn, err := net.ListenTCP("tcp", tcpAddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer tcpConn.Close()
|
|
||||||
|
|
||||||
tlsConn := tls.NewListener(tcpConn, config)
|
|
||||||
defer tlsConn.Close()
|
|
||||||
|
|
||||||
// Start the servers
|
|
||||||
httpServer := &http.Server{
|
|
||||||
Addr: addr,
|
|
||||||
TLSConfig: config,
|
|
||||||
}
|
|
||||||
|
|
||||||
quicServer := &Server{
|
|
||||||
Server: httpServer,
|
|
||||||
}
|
|
||||||
|
|
||||||
if handler == nil {
|
|
||||||
handler = http.DefaultServeMux
|
|
||||||
}
|
|
||||||
httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
quicServer.SetQuicHeaders(w.Header())
|
|
||||||
handler.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
|
|
||||||
hErr := make(chan error)
|
|
||||||
qErr := make(chan error)
|
|
||||||
go func() {
|
|
||||||
hErr <- httpServer.Serve(tlsConn)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
qErr <- quicServer.Serve(udpConn)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-hErr:
|
|
||||||
quicServer.Close()
|
|
||||||
return err
|
|
||||||
case err := <-qErr:
|
|
||||||
// Cannot close the HTTP server or wait for requests to complete properly :/
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -16,19 +16,11 @@ type StreamID = protocol.StreamID
|
||||||
// A VersionNumber is a QUIC version number.
|
// A VersionNumber is a QUIC version number.
|
||||||
type VersionNumber = protocol.VersionNumber
|
type VersionNumber = protocol.VersionNumber
|
||||||
|
|
||||||
const (
|
|
||||||
// VersionGQUIC39 is gQUIC version 39.
|
|
||||||
VersionGQUIC39 = protocol.Version39
|
|
||||||
// VersionGQUIC43 is gQUIC version 43.
|
|
||||||
VersionGQUIC43 = protocol.Version43
|
|
||||||
// VersionGQUIC44 is gQUIC version 44.
|
|
||||||
VersionGQUIC44 = protocol.Version44
|
|
||||||
// VersionMilestone0_10_0 uses TLS
|
|
||||||
VersionMilestone0_10_0 = protocol.VersionMilestone0_10_0
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Cookie can be used to verify the ownership of the client address.
|
// A Cookie can be used to verify the ownership of the client address.
|
||||||
type Cookie = handshake.Cookie
|
type Cookie struct {
|
||||||
|
RemoteAddr string
|
||||||
|
SentTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// ConnectionState records basic details about the QUIC connection.
|
// ConnectionState records basic details about the QUIC connection.
|
||||||
type ConnectionState = handshake.ConnectionState
|
type ConnectionState = handshake.ConnectionState
|
||||||
|
@ -166,11 +158,7 @@ type Config struct {
|
||||||
// If not set, it uses all versions available.
|
// If not set, it uses all versions available.
|
||||||
// Warning: This API should not be considered stable and will change soon.
|
// Warning: This API should not be considered stable and will change soon.
|
||||||
Versions []VersionNumber
|
Versions []VersionNumber
|
||||||
// Ask the server to omit the connection ID sent in the Public Header.
|
// The length of the connection ID in bytes.
|
||||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
|
||||||
// Currently only valid for the client.
|
|
||||||
RequestConnectionIDOmission bool
|
|
||||||
// The length of the connection ID in bytes. Only valid for IETF QUIC.
|
|
||||||
// It can be 0, or any value between 4 and 18.
|
// It can be 0, or any value between 4 and 18.
|
||||||
// If not set, the interpretation depends on where the Config is used:
|
// If not set, the interpretation depends on where the Config is used:
|
||||||
// If used for dialing an address, a 0 byte connection ID will be used.
|
// If used for dialing an address, a 0 byte connection ID will be used.
|
||||||
|
@ -200,13 +188,10 @@ type Config struct {
|
||||||
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||||
// If not set, it will default to 100.
|
// If not set, it will default to 100.
|
||||||
// If set to a negative value, it doesn't allow any bidirectional streams.
|
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
|
||||||
MaxIncomingStreams int
|
MaxIncomingStreams int
|
||||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||||
// This value doesn't have any effect in Google QUIC.
|
|
||||||
// If not set, it will default to 100.
|
// If not set, it will default to 100.
|
||||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
|
||||||
MaxIncomingUniStreams int
|
MaxIncomingUniStreams int
|
||||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||||
KeepAlive bool
|
KeepAlive bool
|
||||||
|
|
|
@ -27,11 +27,12 @@ type SentPacketHandler interface {
|
||||||
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||||
ShouldSendNumPackets() int
|
ShouldSendNumPackets() int
|
||||||
|
|
||||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
|
||||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||||
DequeuePacketForRetransmission() *Packet
|
DequeuePacketForRetransmission() *Packet
|
||||||
DequeueProbePacket() (*Packet, error)
|
DequeueProbePacket() (*Packet, error)
|
||||||
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
|
|
||||||
|
PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen)
|
||||||
|
PopPacketNumber() protocol.PacketNumber
|
||||||
|
|
||||||
GetAlarmTimeout() time.Time
|
GetAlarmTimeout() time.Time
|
||||||
OnAlarm() error
|
OnAlarm() error
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
package quic
|
package ackhandler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The packetNumberGenerator generates the packet number for the next packet
|
// The packetNumberGenerator generates the packet number for the next packet
|
||||||
|
@ -15,13 +16,17 @@ type packetNumberGenerator struct {
|
||||||
|
|
||||||
next protocol.PacketNumber
|
next protocol.PacketNumber
|
||||||
nextToSkip protocol.PacketNumber
|
nextToSkip protocol.PacketNumber
|
||||||
|
|
||||||
|
history []protocol.PacketNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
|
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
|
||||||
return &packetNumberGenerator{
|
g := &packetNumberGenerator{
|
||||||
next: initial,
|
next: initial,
|
||||||
averagePeriod: averagePeriod,
|
averagePeriod: averagePeriod,
|
||||||
}
|
}
|
||||||
|
g.generateNewSkip()
|
||||||
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetNumberGenerator) Peek() protocol.PacketNumber {
|
func (p *packetNumberGenerator) Peek() protocol.PacketNumber {
|
||||||
|
@ -35,6 +40,10 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
|
||||||
p.next++
|
p.next++
|
||||||
|
|
||||||
if p.next == p.nextToSkip {
|
if p.next == p.nextToSkip {
|
||||||
|
if len(p.history)+1 > protocol.MaxTrackedSkippedPackets {
|
||||||
|
p.history = p.history[1:]
|
||||||
|
}
|
||||||
|
p.history = append(p.history, p.next)
|
||||||
p.next++
|
p.next++
|
||||||
p.generateNewSkip()
|
p.generateNewSkip()
|
||||||
}
|
}
|
||||||
|
@ -42,28 +51,28 @@ func (p *packetNumberGenerator) Pop() protocol.PacketNumber {
|
||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packetNumberGenerator) generateNewSkip() error {
|
func (p *packetNumberGenerator) generateNewSkip() {
|
||||||
num, err := p.getRandomNumber()
|
num := p.getRandomNumber()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
skip := protocol.PacketNumber(num) * (p.averagePeriod - 1) / (math.MaxUint16 / 2)
|
skip := protocol.PacketNumber(num) * (p.averagePeriod - 1) / (math.MaxUint16 / 2)
|
||||||
// make sure that there are never two consecutive packet numbers that are skipped
|
// make sure that there are never two consecutive packet numbers that are skipped
|
||||||
p.nextToSkip = p.next + 2 + skip
|
p.nextToSkip = p.next + 2 + skip
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRandomNumber() generates a cryptographically secure random number between 0 and MaxUint16 (= 65535)
|
// getRandomNumber() generates a cryptographically secure random number between 0 and MaxUint16 (= 65535)
|
||||||
// The expectation value is 65535/2
|
// The expectation value is 65535/2
|
||||||
func (p *packetNumberGenerator) getRandomNumber() (uint16, error) {
|
func (p *packetNumberGenerator) getRandomNumber() uint16 {
|
||||||
b := make([]byte, 2)
|
b := make([]byte, 2)
|
||||||
_, err := rand.Read(b)
|
rand.Read(b) // ignore the error here
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
num := uint16(b[0])<<8 + uint16(b[1])
|
num := uint16(b[0])<<8 + uint16(b[1])
|
||||||
return num, nil
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *packetNumberGenerator) Validate(ack *wire.AckFrame) bool {
|
||||||
|
for _, pn := range p.history {
|
||||||
|
if ack.AcksPacket(pn) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
|
@ -2,9 +2,9 @@ package ackhandler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// The receivedPacketHistory stores if a packet number has already been received.
|
// The receivedPacketHistory stores if a packet number has already been received.
|
||||||
|
|
2
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go
generated
vendored
2
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/retransmittable.go
generated
vendored
|
@ -16,8 +16,6 @@ func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
|
||||||
// IsFrameRetransmittable returns true if the frame should be retransmitted.
|
// IsFrameRetransmittable returns true if the frame should be retransmitted.
|
||||||
func IsFrameRetransmittable(f wire.Frame) bool {
|
func IsFrameRetransmittable(f wire.Frame) bool {
|
||||||
switch f.(type) {
|
switch f.(type) {
|
||||||
case *wire.StopWaitingFrame:
|
|
||||||
return false
|
|
||||||
case *wire.AckFrame:
|
case *wire.AckFrame:
|
||||||
return false
|
return false
|
||||||
default:
|
default:
|
||||||
|
|
67
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
67
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
|
@ -8,9 +8,9 @@ import (
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -31,11 +31,12 @@ const (
|
||||||
|
|
||||||
type sentPacketHandler struct {
|
type sentPacketHandler struct {
|
||||||
lastSentPacketNumber protocol.PacketNumber
|
lastSentPacketNumber protocol.PacketNumber
|
||||||
|
packetNumberGenerator *packetNumberGenerator
|
||||||
|
|
||||||
lastSentRetransmittablePacketTime time.Time
|
lastSentRetransmittablePacketTime time.Time
|
||||||
lastSentHandshakePacketTime time.Time
|
lastSentHandshakePacketTime time.Time
|
||||||
|
|
||||||
nextPacketSendTime time.Time
|
nextPacketSendTime time.Time
|
||||||
skippedPackets []protocol.PacketNumber
|
|
||||||
|
|
||||||
largestAcked protocol.PacketNumber
|
largestAcked protocol.PacketNumber
|
||||||
largestReceivedPacketWithAck protocol.PacketNumber
|
largestReceivedPacketWithAck protocol.PacketNumber
|
||||||
|
@ -46,7 +47,6 @@ type sentPacketHandler struct {
|
||||||
largestSentBeforeRTO protocol.PacketNumber
|
largestSentBeforeRTO protocol.PacketNumber
|
||||||
|
|
||||||
packetHistory *sentPacketHistory
|
packetHistory *sentPacketHistory
|
||||||
stopWaitingManager stopWaitingManager
|
|
||||||
|
|
||||||
retransmissionQueue []*Packet
|
retransmissionQueue []*Packet
|
||||||
|
|
||||||
|
@ -90,8 +90,8 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
|
||||||
)
|
)
|
||||||
|
|
||||||
return &sentPacketHandler{
|
return &sentPacketHandler{
|
||||||
|
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
|
||||||
packetHistory: newSentPacketHistory(),
|
packetHistory: newSentPacketHistory(),
|
||||||
stopWaitingManager: stopWaitingManager{},
|
|
||||||
rttStats: rttStats,
|
rttStats: rttStats,
|
||||||
congestion: congestion,
|
congestion: congestion,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
@ -110,13 +110,13 @@ func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||||
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
|
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
|
||||||
var queue []*Packet
|
var queue []*Packet
|
||||||
for _, packet := range h.retransmissionQueue {
|
for _, packet := range h.retransmissionQueue {
|
||||||
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
if packet.EncryptionLevel == protocol.Encryption1RTT {
|
||||||
queue = append(queue, packet)
|
queue = append(queue, packet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var handshakePackets []*Packet
|
var handshakePackets []*Packet
|
||||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||||
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
|
if p.EncryptionLevel != protocol.Encryption1RTT {
|
||||||
handshakePackets = append(handshakePackets, p)
|
handshakePackets = append(handshakePackets, p)
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
|
@ -148,10 +148,7 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra
|
||||||
|
|
||||||
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
|
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
|
||||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||||
h.skippedPackets = append(h.skippedPackets, p)
|
h.logger.Debugf("Skipping packet number %#x", p)
|
||||||
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
|
||||||
h.skippedPackets = h.skippedPackets[1:]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
h.lastSentPacketNumber = packet.PacketNumber
|
h.lastSentPacketNumber = packet.PacketNumber
|
||||||
|
@ -166,7 +163,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
|
||||||
isRetransmittable := len(packet.Frames) != 0
|
isRetransmittable := len(packet.Frames) != 0
|
||||||
|
|
||||||
if isRetransmittable {
|
if isRetransmittable {
|
||||||
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
|
if packet.EncryptionLevel != protocol.Encryption1RTT {
|
||||||
h.lastSentHandshakePacketTime = packet.SendTime
|
h.lastSentHandshakePacketTime = packet.SendTime
|
||||||
}
|
}
|
||||||
h.lastSentRetransmittablePacketTime = packet.SendTime
|
h.lastSentRetransmittablePacketTime = packet.SendTime
|
||||||
|
@ -198,7 +195,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
||||||
h.largestReceivedPacketWithAck = withPacketNumber
|
h.largestReceivedPacketWithAck = withPacketNumber
|
||||||
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
|
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
|
||||||
|
|
||||||
if h.skippedPacketsAcked(ackFrame) {
|
if !h.packetNumberGenerator.Validate(ackFrame) {
|
||||||
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -213,9 +210,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
||||||
|
|
||||||
priorInFlight := h.bytesInFlight
|
priorInFlight := h.bytesInFlight
|
||||||
for _, p := range ackedPackets {
|
for _, p := range ackedPackets {
|
||||||
if encLevel < p.EncryptionLevel {
|
// TODO(#1534): check the encryption level
|
||||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
// if encLevel < p.EncryptionLevel {
|
||||||
}
|
// return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||||
|
// }
|
||||||
|
|
||||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||||
|
@ -234,10 +233,6 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
h.updateLossDetectionAlarm()
|
h.updateLossDetectionAlarm()
|
||||||
|
|
||||||
h.garbageCollectSkippedPackets()
|
|
||||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -519,12 +514,13 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
|
||||||
return h.DequeuePacketForRetransmission(), nil
|
return h.DequeuePacketForRetransmission(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
|
pn := h.packetNumberGenerator.Peek()
|
||||||
|
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
return h.packetNumberGenerator.Pop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) SendMode() SendMode {
|
func (h *sentPacketHandler) SendMode() SendMode {
|
||||||
|
@ -585,7 +581,7 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
||||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
||||||
var handshakePackets []*Packet
|
var handshakePackets []*Packet
|
||||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||||
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
if p.canBeRetransmitted && p.EncryptionLevel != protocol.Encryption1RTT {
|
||||||
handshakePackets = append(handshakePackets, p)
|
handshakePackets = append(handshakePackets, p)
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
|
@ -607,7 +603,6 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
h.retransmissionQueue = append(h.retransmissionQueue, p)
|
h.retransmissionQueue = append(h.retransmissionQueue, p)
|
||||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -633,26 +628,6 @@ func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||||
}
|
}
|
||||||
rto = utils.MaxDuration(rto, minRTOTimeout)
|
rto = utils.MaxDuration(rto, minRTOTimeout)
|
||||||
// Exponential backoff
|
// Exponential backoff
|
||||||
rto = rto << h.rtoCount
|
rto <<= h.rtoCount
|
||||||
return utils.MinDuration(rto, maxRTOTimeout)
|
return utils.MinDuration(rto, maxRTOTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
|
||||||
for _, p := range h.skippedPackets {
|
|
||||||
if ackFrame.AcksPacket(p) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
|
||||||
lowestUnacked := h.lowestUnacked()
|
|
||||||
deleteIndex := 0
|
|
||||||
for i, p := range h.skippedPackets {
|
|
||||||
if p < lowestUnacked {
|
|
||||||
deleteIndex = i + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.skippedPackets = h.skippedPackets[deleteIndex:]
|
|
||||||
}
|
|
||||||
|
|
6
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
6
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
|
@ -35,7 +35,7 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
|
||||||
}
|
}
|
||||||
if p.canBeRetransmitted {
|
if p.canBeRetransmitted {
|
||||||
h.numOutstandingPackets++
|
h.numOutstandingPackets++
|
||||||
if p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
if p.EncryptionLevel != protocol.Encryption1RTT {
|
||||||
h.numOutstandingHandshakePackets++
|
h.numOutstandingHandshakePackets++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -106,7 +106,7 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
|
||||||
if h.numOutstandingPackets < 0 {
|
if h.numOutstandingPackets < 0 {
|
||||||
panic("numOutstandingHandshakePackets negative")
|
panic("numOutstandingHandshakePackets negative")
|
||||||
}
|
}
|
||||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
|
||||||
h.numOutstandingHandshakePackets--
|
h.numOutstandingHandshakePackets--
|
||||||
if h.numOutstandingHandshakePackets < 0 {
|
if h.numOutstandingHandshakePackets < 0 {
|
||||||
panic("numOutstandingHandshakePackets negative")
|
panic("numOutstandingHandshakePackets negative")
|
||||||
|
@ -147,7 +147,7 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
||||||
if h.numOutstandingPackets < 0 {
|
if h.numOutstandingPackets < 0 {
|
||||||
panic("numOutstandingHandshakePackets negative")
|
panic("numOutstandingHandshakePackets negative")
|
||||||
}
|
}
|
||||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
|
||||||
h.numOutstandingHandshakePackets--
|
h.numOutstandingHandshakePackets--
|
||||||
if h.numOutstandingHandshakePackets < 0 {
|
if h.numOutstandingHandshakePackets < 0 {
|
||||||
panic("numOutstandingHandshakePackets negative")
|
panic("numOutstandingHandshakePackets negative")
|
||||||
|
|
43
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go
generated
vendored
43
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager.go
generated
vendored
|
@ -1,43 +0,0 @@
|
||||||
package ackhandler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
|
|
||||||
type stopWaitingManager struct {
|
|
||||||
largestLeastUnackedSent protocol.PacketNumber
|
|
||||||
nextLeastUnacked protocol.PacketNumber
|
|
||||||
|
|
||||||
lastStopWaitingFrame *wire.StopWaitingFrame
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
|
||||||
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
|
|
||||||
if force {
|
|
||||||
return s.lastStopWaitingFrame
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
s.largestLeastUnackedSent = s.nextLeastUnacked
|
|
||||||
swf := &wire.StopWaitingFrame{
|
|
||||||
LeastUnacked: s.nextLeastUnacked,
|
|
||||||
}
|
|
||||||
s.lastStopWaitingFrame = swf
|
|
||||||
return swf
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
|
||||||
largestAcked := ack.LargestAcked()
|
|
||||||
if largestAcked >= s.nextLeastUnacked {
|
|
||||||
s.nextLeastUnacked = largestAcked + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
|
|
||||||
if p >= s.nextLeastUnacked {
|
|
||||||
s.nextLeastUnacked = p + 1
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -193,7 +193,7 @@ func (c *cubicSender) OnPacketLost(
|
||||||
if c.congestionWindow >= 2*c.initialCongestionWindow {
|
if c.congestionWindow >= 2*c.initialCongestionWindow {
|
||||||
c.minSlowStartExitWindow = c.congestionWindow / 2
|
c.minSlowStartExitWindow = c.congestionWindow / 2
|
||||||
}
|
}
|
||||||
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
|
c.congestionWindow -= protocol.DefaultTCPMSS
|
||||||
} else if c.reno {
|
} else if c.reno {
|
||||||
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
|
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1,72 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/cipher"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/aes12"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
type aeadAESGCM12 struct {
|
|
||||||
otherIV []byte
|
|
||||||
myIV []byte
|
|
||||||
encrypter cipher.AEAD
|
|
||||||
decrypter cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ AEAD = &aeadAESGCM12{}
|
|
||||||
|
|
||||||
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
|
|
||||||
//
|
|
||||||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
|
||||||
// tag size, and couples the cipher and aes packages closely.
|
|
||||||
// See https://github.com/lucas-clemente/aes12.
|
|
||||||
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
|
||||||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
|
||||||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
|
||||||
}
|
|
||||||
encrypterCipher, err := aes12.NewCipher(myKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
encrypter, err := aes12.NewGCM(encrypterCipher)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypterCipher, err := aes12.NewCipher(otherKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypter, err := aes12.NewGCM(decrypterCipher)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &aeadAESGCM12{
|
|
||||||
otherIV: otherIV,
|
|
||||||
myIV: myIV,
|
|
||||||
encrypter: encrypter,
|
|
||||||
decrypter: decrypter,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
|
||||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
|
||||||
res := make([]byte, 12)
|
|
||||||
copy(res[0:4], iv)
|
|
||||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadAESGCM12) Overhead() int {
|
|
||||||
return aead.encrypter.Overhead()
|
|
||||||
}
|
|
|
@ -1,48 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"hash/fnv"
|
|
||||||
|
|
||||||
"github.com/hashicorp/golang-lru"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
compressedCertsCache *lru.Cache
|
|
||||||
)
|
|
||||||
|
|
||||||
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
|
||||||
// Hash all inputs
|
|
||||||
hasher := fnv.New64a()
|
|
||||||
for _, v := range chain {
|
|
||||||
hasher.Write(v)
|
|
||||||
}
|
|
||||||
hasher.Write(pCommonSetHashes)
|
|
||||||
hasher.Write(pCachedHashes)
|
|
||||||
hash := hasher.Sum64()
|
|
||||||
|
|
||||||
var result []byte
|
|
||||||
|
|
||||||
resultI, isCached := compressedCertsCache.Get(hash)
|
|
||||||
if isCached {
|
|
||||||
result = resultI.([]byte)
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
result, err = compressChain(chain, pCommonSetHashes, pCachedHashes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
compressedCertsCache.Add(hash, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
var err error
|
|
||||||
compressedCertsCache, err = lru.New(protocol.NumCachedCertificates)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error()))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,113 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A CertChain holds a certificate and a private key
|
|
||||||
type CertChain interface {
|
|
||||||
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
|
|
||||||
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
|
|
||||||
GetLeafCert(sni string) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// proofSource stores a key and a certificate for the server proof
|
|
||||||
type certChain struct {
|
|
||||||
config *tls.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ CertChain = &certChain{}
|
|
||||||
|
|
||||||
var errNoMatchingCertificate = errors.New("no matching certificate found")
|
|
||||||
|
|
||||||
// NewCertChain loads the key and cert from files
|
|
||||||
func NewCertChain(tlsConfig *tls.Config) CertChain {
|
|
||||||
return &certChain{config: tlsConfig}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignServerProof signs CHLO and server config for use in the server proof
|
|
||||||
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
|
|
||||||
cert, err := c.getCertForSNI(sni)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return signServerProof(cert, chlo, serverConfigData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
|
|
||||||
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
|
||||||
cert, err := c.getCertForSNI(sni)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLeafCert gets the leaf certificate
|
|
||||||
func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
|
|
||||||
cert, err := c.getCertForSNI(sni)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return cert.Certificate[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
|
||||||
conf := c.config
|
|
||||||
conf, err := maybeGetConfigForClient(conf, sni)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// The rest of this function is mostly copied from crypto/tls.getCertificate
|
|
||||||
|
|
||||||
if conf.GetCertificate != nil {
|
|
||||||
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
|
||||||
if cert != nil || err != nil {
|
|
||||||
return cert, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(conf.Certificates) == 0 {
|
|
||||||
return nil, errNoMatchingCertificate
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
|
|
||||||
// There's only one choice, so no point doing any work.
|
|
||||||
return &conf.Certificates[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
name := strings.ToLower(sni)
|
|
||||||
for len(name) > 0 && name[len(name)-1] == '.' {
|
|
||||||
name = name[:len(name)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert, ok := conf.NameToCertificate[name]; ok {
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// try replacing labels in the name with wildcards until we get a
|
|
||||||
// match.
|
|
||||||
labels := strings.Split(name, ".")
|
|
||||||
for i := range labels {
|
|
||||||
labels[i] = "*"
|
|
||||||
candidate := strings.Join(labels, ".")
|
|
||||||
if cert, ok := conf.NameToCertificate[candidate]; ok {
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If nothing matches, return the first certificate.
|
|
||||||
return &conf.Certificates[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
|
|
||||||
if c.GetConfigForClient == nil {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
return c.GetConfigForClient(&tls.ClientHelloInfo{
|
|
||||||
ServerName: sni,
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,272 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"compress/flate"
|
|
||||||
"compress/zlib"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash/fnv"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type entryType uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
entryCompressed entryType = 1
|
|
||||||
entryCached entryType = 2
|
|
||||||
entryCommon entryType = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
type entry struct {
|
|
||||||
t entryType
|
|
||||||
h uint64 // set hash
|
|
||||||
i uint32 // index
|
|
||||||
}
|
|
||||||
|
|
||||||
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
|
||||||
res := &bytes.Buffer{}
|
|
||||||
|
|
||||||
cachedHashes, err := splitHashes(pCachedHashes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
setHashes, err := splitHashes(pCommonSetHashes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
chainHashes := make([]uint64, len(chain))
|
|
||||||
for i := range chain {
|
|
||||||
chainHashes[i] = HashCert(chain[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
|
|
||||||
|
|
||||||
totalUncompressedLen := 0
|
|
||||||
for i, e := range entries {
|
|
||||||
res.WriteByte(uint8(e.t))
|
|
||||||
switch e.t {
|
|
||||||
case entryCached:
|
|
||||||
utils.LittleEndian.WriteUint64(res, e.h)
|
|
||||||
case entryCommon:
|
|
||||||
utils.LittleEndian.WriteUint64(res, e.h)
|
|
||||||
utils.LittleEndian.WriteUint32(res, e.i)
|
|
||||||
case entryCompressed:
|
|
||||||
totalUncompressedLen += 4 + len(chain[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
res.WriteByte(0) // end of list
|
|
||||||
|
|
||||||
if totalUncompressedLen > 0 {
|
|
||||||
gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
|
|
||||||
|
|
||||||
for i, e := range entries {
|
|
||||||
if e.t != entryCompressed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
lenCert := len(chain[i])
|
|
||||||
gz.Write([]byte{
|
|
||||||
byte(lenCert & 0xff),
|
|
||||||
byte((lenCert >> 8) & 0xff),
|
|
||||||
byte((lenCert >> 16) & 0xff),
|
|
||||||
byte((lenCert >> 24) & 0xff),
|
|
||||||
})
|
|
||||||
gz.Write(chain[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
gz.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return res.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decompressChain(data []byte) ([][]byte, error) {
|
|
||||||
var chain [][]byte
|
|
||||||
var entries []entry
|
|
||||||
r := bytes.NewReader(data)
|
|
||||||
|
|
||||||
var numCerts int
|
|
||||||
var hasCompressedCerts bool
|
|
||||||
for {
|
|
||||||
entryTypeByte, err := r.ReadByte()
|
|
||||||
if entryTypeByte == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
et := entryType(entryTypeByte)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
numCerts++
|
|
||||||
|
|
||||||
switch et {
|
|
||||||
case entryCached:
|
|
||||||
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
|
|
||||||
return nil, errors.New("unexpected cached certificate")
|
|
||||||
case entryCommon:
|
|
||||||
e := entry{t: entryCommon}
|
|
||||||
e.h, err = utils.LittleEndian.ReadUint64(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
e.i, err = utils.LittleEndian.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
certSet, ok := certSets[e.h]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("unknown certSet")
|
|
||||||
}
|
|
||||||
if e.i >= uint32(len(certSet)) {
|
|
||||||
return nil, errors.New("certificate not found in certSet")
|
|
||||||
}
|
|
||||||
entries = append(entries, e)
|
|
||||||
chain = append(chain, certSet[e.i])
|
|
||||||
case entryCompressed:
|
|
||||||
hasCompressedCerts = true
|
|
||||||
entries = append(entries, entry{t: entryCompressed})
|
|
||||||
chain = append(chain, nil)
|
|
||||||
default:
|
|
||||||
return nil, errors.New("unknown entryType")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if numCerts == 0 {
|
|
||||||
return make([][]byte, 0), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasCompressedCerts {
|
|
||||||
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(4)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
zlibDict := buildZlibDictForEntries(entries, chain)
|
|
||||||
gz, err := zlib.NewReaderDict(r, zlibDict)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer gz.Close()
|
|
||||||
|
|
||||||
var totalLength uint32
|
|
||||||
var certIndex int
|
|
||||||
for totalLength < uncompressedLength {
|
|
||||||
lenBytes := make([]byte, 4)
|
|
||||||
_, err := gz.Read(lenBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
certLen := binary.LittleEndian.Uint32(lenBytes)
|
|
||||||
|
|
||||||
cert := make([]byte, certLen)
|
|
||||||
n, err := gz.Read(cert)
|
|
||||||
if uint32(n) != certLen && err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
if certIndex >= len(entries) {
|
|
||||||
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
|
|
||||||
}
|
|
||||||
if entries[certIndex].t == entryCompressed {
|
|
||||||
chain[certIndex] = cert
|
|
||||||
certIndex++
|
|
||||||
break
|
|
||||||
}
|
|
||||||
certIndex++
|
|
||||||
}
|
|
||||||
|
|
||||||
totalLength += 4 + certLen
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return chain, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
|
|
||||||
res := make([]entry, len(chain))
|
|
||||||
chainLoop:
|
|
||||||
for i := range chain {
|
|
||||||
// Check if hash is in cachedHashes
|
|
||||||
for j := range cachedHashes {
|
|
||||||
if chainHashes[i] == cachedHashes[j] {
|
|
||||||
res[i] = entry{t: entryCached, h: chainHashes[i]}
|
|
||||||
continue chainLoop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Go through common sets and check if it's in there
|
|
||||||
for _, setHash := range setHashes {
|
|
||||||
set, ok := certSets[setHash]
|
|
||||||
if !ok {
|
|
||||||
// We don't have this set
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// We have this set, check if chain[i] is in the set
|
|
||||||
pos := set.findCertInSet(chain[i])
|
|
||||||
if pos >= 0 {
|
|
||||||
// Found
|
|
||||||
res[i] = entry{t: entryCommon, h: setHash, i: uint32(pos)}
|
|
||||||
continue chainLoop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res[i] = entry{t: entryCompressed}
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte {
|
|
||||||
var dict bytes.Buffer
|
|
||||||
|
|
||||||
// First the cached and common in reverse order
|
|
||||||
for i := len(entries) - 1; i >= 0; i-- {
|
|
||||||
if entries[i].t == entryCompressed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dict.Write(chain[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
dict.Write(certDictZlib)
|
|
||||||
return dict.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitHashes(hashes []byte) ([]uint64, error) {
|
|
||||||
if len(hashes)%8 != 0 {
|
|
||||||
return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes")
|
|
||||||
}
|
|
||||||
n := len(hashes) / 8
|
|
||||||
res := make([]uint64, n)
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8])
|
|
||||||
}
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCommonCertificateHashes() []byte {
|
|
||||||
ccs := make([]byte, 8*len(certSets))
|
|
||||||
i := 0
|
|
||||||
for certSetHash := range certSets {
|
|
||||||
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
return ccs
|
|
||||||
}
|
|
||||||
|
|
||||||
// HashCert calculates the FNV1a hash of a certificate
|
|
||||||
func HashCert(cert []byte) uint64 {
|
|
||||||
h := fnv.New64a()
|
|
||||||
h.Write(cert)
|
|
||||||
return h.Sum64()
|
|
||||||
}
|
|
|
@ -1,128 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
var certDictZlib = []byte{
|
|
||||||
0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04,
|
|
||||||
0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03,
|
|
||||||
0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x30,
|
|
||||||
0x5f, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x04, 0x01,
|
|
||||||
0x06, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, 0xfd, 0x6d, 0x01, 0x07,
|
|
||||||
0x17, 0x01, 0x30, 0x33, 0x20, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65,
|
|
||||||
0x64, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
|
||||||
0x20, 0x53, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x34,
|
|
||||||
0x20, 0x53, 0x53, 0x4c, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31,
|
|
||||||
0x32, 0x20, 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x53, 0x65, 0x72,
|
|
||||||
0x76, 0x65, 0x72, 0x20, 0x43, 0x41, 0x30, 0x2d, 0x61, 0x69, 0x61, 0x2e,
|
|
||||||
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
|
|
||||||
0x2f, 0x45, 0x2d, 0x63, 0x72, 0x6c, 0x2e, 0x76, 0x65, 0x72, 0x69, 0x73,
|
|
||||||
0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x45, 0x2e, 0x63, 0x65,
|
|
||||||
0x72, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01,
|
|
||||||
0x01, 0x05, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x4a, 0x2e, 0x63,
|
|
||||||
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73,
|
|
||||||
0x2f, 0x63, 0x70, 0x73, 0x20, 0x28, 0x63, 0x29, 0x30, 0x30, 0x09, 0x06,
|
|
||||||
0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x30, 0x0d,
|
|
||||||
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
|
|
||||||
0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x7b, 0x30, 0x1d, 0x06, 0x03, 0x55,
|
|
||||||
0x1d, 0x0e, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86,
|
|
||||||
0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01,
|
|
||||||
0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xd2,
|
|
||||||
0x6f, 0x64, 0x6f, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x43, 0x2e,
|
|
||||||
0x63, 0x72, 0x6c, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
|
|
||||||
0x04, 0x14, 0xb4, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69,
|
|
||||||
0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x30, 0x0b, 0x06, 0x03,
|
|
||||||
0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x30, 0x0d, 0x06, 0x09,
|
|
||||||
0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30,
|
|
||||||
0x81, 0xca, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
|
|
||||||
0x02, 0x55, 0x53, 0x31, 0x10, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x08,
|
|
||||||
0x13, 0x07, 0x41, 0x72, 0x69, 0x7a, 0x6f, 0x6e, 0x61, 0x31, 0x13, 0x30,
|
|
||||||
0x11, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0a, 0x53, 0x63, 0x6f, 0x74,
|
|
||||||
0x74, 0x73, 0x64, 0x61, 0x6c, 0x65, 0x31, 0x1a, 0x30, 0x18, 0x06, 0x03,
|
|
||||||
0x55, 0x04, 0x0a, 0x13, 0x11, 0x47, 0x6f, 0x44, 0x61, 0x64, 0x64, 0x79,
|
|
||||||
0x2e, 0x63, 0x6f, 0x6d, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, 0x33,
|
|
||||||
0x30, 0x31, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x2a, 0x68, 0x74, 0x74,
|
|
||||||
0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63,
|
|
||||||
0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79,
|
|
||||||
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74,
|
|
||||||
0x6f, 0x72, 0x79, 0x31, 0x30, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x04, 0x03,
|
|
||||||
0x13, 0x27, 0x47, 0x6f, 0x20, 0x44, 0x61, 0x64, 0x64, 0x79, 0x20, 0x53,
|
|
||||||
0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66,
|
|
||||||
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68,
|
|
||||||
0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55,
|
|
||||||
0x04, 0x05, 0x13, 0x08, 0x30, 0x37, 0x39, 0x36, 0x39, 0x32, 0x38, 0x37,
|
|
||||||
0x30, 0x1e, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d,
|
|
||||||
0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x0c,
|
|
||||||
0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x02, 0x30, 0x00,
|
|
||||||
0x30, 0x1d, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff,
|
|
||||||
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55,
|
|
||||||
0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05,
|
|
||||||
0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07,
|
|
||||||
0x03, 0x02, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff,
|
|
||||||
0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x33, 0x06, 0x03, 0x55, 0x1d,
|
|
||||||
0x1f, 0x04, 0x2c, 0x30, 0x2a, 0x30, 0x28, 0xa0, 0x26, 0xa0, 0x24, 0x86,
|
|
||||||
0x22, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e,
|
|
||||||
0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
|
|
||||||
0x67, 0x64, 0x73, 0x31, 0x2d, 0x32, 0x30, 0x2a, 0x30, 0x28, 0x06, 0x08,
|
|
||||||
0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x1c, 0x68, 0x74,
|
|
||||||
0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, 0x65,
|
|
||||||
0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
|
|
||||||
0x70, 0x73, 0x30, 0x34, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x5a, 0x17,
|
|
||||||
0x0d, 0x31, 0x33, 0x30, 0x35, 0x30, 0x39, 0x06, 0x08, 0x2b, 0x06, 0x01,
|
|
||||||
0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x3a,
|
|
||||||
0x2f, 0x2f, 0x73, 0x30, 0x39, 0x30, 0x37, 0x06, 0x08, 0x2b, 0x06, 0x01,
|
|
||||||
0x05, 0x05, 0x07, 0x02, 0x30, 0x44, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04,
|
|
||||||
0x3d, 0x30, 0x3b, 0x30, 0x39, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86,
|
|
||||||
0xf8, 0x45, 0x01, 0x07, 0x17, 0x06, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
|
|
||||||
0x55, 0x04, 0x06, 0x13, 0x02, 0x47, 0x42, 0x31, 0x1b, 0x53, 0x31, 0x17,
|
|
||||||
0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0e, 0x56, 0x65, 0x72,
|
|
||||||
0x69, 0x53, 0x69, 0x67, 0x6e, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31,
|
|
||||||
0x1f, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x16, 0x56, 0x65,
|
|
||||||
0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x54, 0x72, 0x75, 0x73, 0x74,
|
|
||||||
0x20, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x31, 0x3b, 0x30, 0x39,
|
|
||||||
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x32, 0x54, 0x65, 0x72, 0x6d, 0x73,
|
|
||||||
0x20, 0x6f, 0x66, 0x20, 0x75, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20, 0x68,
|
|
||||||
0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76,
|
|
||||||
0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
|
|
||||||
0x72, 0x70, 0x61, 0x20, 0x28, 0x63, 0x29, 0x30, 0x31, 0x10, 0x30, 0x0e,
|
|
||||||
0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x07, 0x53, 0x31, 0x13, 0x30, 0x11,
|
|
||||||
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0a, 0x47, 0x31, 0x13, 0x30, 0x11,
|
|
||||||
0x06, 0x0b, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x3c, 0x02, 0x01,
|
|
||||||
0x03, 0x13, 0x02, 0x55, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04,
|
|
||||||
0x03, 0x14, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13,
|
|
||||||
0x31, 0x1d, 0x30, 0x1b, 0x06, 0x03, 0x55, 0x04, 0x0f, 0x13, 0x14, 0x50,
|
|
||||||
0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e,
|
|
||||||
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x31, 0x12, 0x31, 0x21, 0x30,
|
|
||||||
0x1f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x18, 0x44, 0x6f, 0x6d, 0x61,
|
|
||||||
0x69, 0x6e, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x20, 0x56,
|
|
||||||
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x64, 0x31, 0x14, 0x31, 0x31,
|
|
||||||
0x30, 0x2f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x28, 0x53, 0x65, 0x65,
|
|
||||||
0x20, 0x77, 0x77, 0x77, 0x2e, 0x72, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63,
|
|
||||||
0x75, 0x72, 0x65, 0x2e, 0x67, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53,
|
|
||||||
0x69, 0x67, 0x6e, 0x31, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x41,
|
|
||||||
0x2e, 0x63, 0x72, 0x6c, 0x56, 0x65, 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e,
|
|
||||||
0x20, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x33, 0x20, 0x45, 0x63, 0x72,
|
|
||||||
0x6c, 0x2e, 0x67, 0x65, 0x6f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x2e, 0x63,
|
|
||||||
0x6f, 0x6d, 0x2f, 0x63, 0x72, 0x6c, 0x73, 0x2f, 0x73, 0x64, 0x31, 0x1a,
|
|
||||||
0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x3a,
|
|
||||||
0x2f, 0x2f, 0x45, 0x56, 0x49, 0x6e, 0x74, 0x6c, 0x2d, 0x63, 0x63, 0x72,
|
|
||||||
0x74, 0x2e, 0x67, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x69, 0x63, 0x65, 0x72,
|
|
||||||
0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x31, 0x6f, 0x63, 0x73, 0x70, 0x2e,
|
|
||||||
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
|
|
||||||
0x30, 0x39, 0x72, 0x61, 0x70, 0x69, 0x64, 0x73, 0x73, 0x6c, 0x2e, 0x63,
|
|
||||||
0x6f, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63,
|
|
||||||
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72,
|
|
||||||
0x79, 0x2f, 0x30, 0x81, 0x80, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05,
|
|
||||||
0x07, 0x01, 0x01, 0x04, 0x74, 0x30, 0x72, 0x30, 0x24, 0x06, 0x08, 0x2b,
|
|
||||||
0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x18, 0x68, 0x74, 0x74,
|
|
||||||
0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x2e, 0x67, 0x6f, 0x64,
|
|
||||||
0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x4a, 0x06,
|
|
||||||
0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x3e, 0x68,
|
|
||||||
0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66,
|
|
||||||
0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64,
|
|
||||||
0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73,
|
|
||||||
0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x67, 0x64, 0x5f, 0x69, 0x6e, 0x74,
|
|
||||||
0x65, 0x72, 0x6d, 0x65, 0x64, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x72,
|
|
||||||
0x74, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16,
|
|
||||||
0x80, 0x14, 0xfd, 0xac, 0x61, 0x32, 0x93, 0x6c, 0x45, 0xd6, 0xe2, 0xee,
|
|
||||||
0x85, 0x5f, 0x9a, 0xba, 0xe7, 0x76, 0x99, 0x68, 0xcc, 0xe7, 0x30, 0x27,
|
|
||||||
0x86, 0x29, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x86, 0x30,
|
|
||||||
0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73,
|
|
||||||
}
|
|
|
@ -1,135 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"errors"
|
|
||||||
"hash/fnv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CertManager manages the certificates sent by the server
|
|
||||||
type CertManager interface {
|
|
||||||
SetData([]byte) error
|
|
||||||
GetCommonCertificateHashes() []byte
|
|
||||||
GetLeafCert() []byte
|
|
||||||
GetLeafCertHash() (uint64, error)
|
|
||||||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
|
||||||
Verify(hostname string) error
|
|
||||||
GetChain() []*x509.Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
type certManager struct {
|
|
||||||
chain []*x509.Certificate
|
|
||||||
config *tls.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ CertManager = &certManager{}
|
|
||||||
|
|
||||||
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
|
|
||||||
|
|
||||||
// NewCertManager creates a new CertManager
|
|
||||||
func NewCertManager(tlsConfig *tls.Config) CertManager {
|
|
||||||
return &certManager{config: tlsConfig}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
|
|
||||||
func (c *certManager) SetData(data []byte) error {
|
|
||||||
byteChain, err := decompressChain(data)
|
|
||||||
if err != nil {
|
|
||||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
chain := make([]*x509.Certificate, len(byteChain))
|
|
||||||
for i, data := range byteChain {
|
|
||||||
cert, err := x509.ParseCertificate(data)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
chain[i] = cert
|
|
||||||
}
|
|
||||||
|
|
||||||
c.chain = chain
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certManager) GetChain() []*x509.Certificate {
|
|
||||||
return c.chain
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *certManager) GetCommonCertificateHashes() []byte {
|
|
||||||
return getCommonCertificateHashes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLeafCert returns the leaf certificate of the certificate chain
|
|
||||||
// it returns nil if the certificate chain has not yet been set
|
|
||||||
func (c *certManager) GetLeafCert() []byte {
|
|
||||||
if len(c.chain) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.chain[0].Raw
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
|
|
||||||
func (c *certManager) GetLeafCertHash() (uint64, error) {
|
|
||||||
leafCert := c.GetLeafCert()
|
|
||||||
if leafCert == nil {
|
|
||||||
return 0, errNoCertificateChain
|
|
||||||
}
|
|
||||||
|
|
||||||
h := fnv.New64a()
|
|
||||||
_, err := h.Write(leafCert)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return h.Sum64(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// VerifyServerProof verifies the signature of the server config
|
|
||||||
// it should only be called after the certificate chain has been set, otherwise it returns false
|
|
||||||
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
|
|
||||||
if len(c.chain) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify verifies the certificate chain
|
|
||||||
func (c *certManager) Verify(hostname string) error {
|
|
||||||
if len(c.chain) == 0 {
|
|
||||||
return errNoCertificateChain
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config != nil && c.config.InsecureSkipVerify {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
leafCert := c.chain[0]
|
|
||||||
|
|
||||||
var opts x509.VerifyOptions
|
|
||||||
if c.config != nil {
|
|
||||||
opts.Roots = c.config.RootCAs
|
|
||||||
if c.config.Time == nil {
|
|
||||||
opts.CurrentTime = time.Now()
|
|
||||||
} else {
|
|
||||||
opts.CurrentTime = c.config.Time()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
|
|
||||||
opts.DNSName = hostname
|
|
||||||
|
|
||||||
// the first certificate is the leaf certificate, all others are intermediates
|
|
||||||
if len(c.chain) > 1 {
|
|
||||||
intermediates := x509.NewCertPool()
|
|
||||||
for i := 1; i < len(c.chain); i++ {
|
|
||||||
intermediates.AddCert(c.chain[i])
|
|
||||||
}
|
|
||||||
opts.Intermediates = intermediates
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := leafCert.Verify(opts)
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,24 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go-certificates"
|
|
||||||
)
|
|
||||||
|
|
||||||
type certSet [][]byte
|
|
||||||
|
|
||||||
var certSets = map[uint64]certSet{
|
|
||||||
certsets.CertSet2Hash: certsets.CertSet2,
|
|
||||||
certsets.CertSet3Hash: certsets.CertSet3,
|
|
||||||
}
|
|
||||||
|
|
||||||
// findCertInSet searches for the cert in the set. Negative return value means not found.
|
|
||||||
func (s *certSet) findCertInSet(cert []byte) int {
|
|
||||||
for i, c := range *s {
|
|
||||||
if bytes.Equal(c, cert) {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
61
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go
generated
vendored
61
vendor/github.com/lucas-clemente/quic-go/internal/crypto/chacha20poly1305_aead.go
generated
vendored
|
@ -1,61 +0,0 @@
|
||||||
// +build ignore
|
|
||||||
|
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/cipher"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/aead/chacha20"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
type aeadChacha20Poly1305 struct {
|
|
||||||
otherIV []byte
|
|
||||||
myIV []byte
|
|
||||||
encrypter cipher.AEAD
|
|
||||||
decrypter cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
|
|
||||||
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
|
||||||
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
|
|
||||||
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
|
|
||||||
}
|
|
||||||
// copy because ChaCha20Poly1305 expects array pointers
|
|
||||||
var MyKey, OtherKey [32]byte
|
|
||||||
copy(MyKey[:], myKey)
|
|
||||||
copy(OtherKey[:], otherKey)
|
|
||||||
|
|
||||||
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &aeadChacha20Poly1305{
|
|
||||||
otherIV: otherIV,
|
|
||||||
myIV: myIV,
|
|
||||||
encrypter: encrypter,
|
|
||||||
decrypter: decrypter,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
|
||||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
|
||||||
res := make([]byte, 12)
|
|
||||||
copy(res[0:4], iv)
|
|
||||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
|
||||||
return res
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/curve25519"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KeyExchange manages the exchange of keys
|
|
||||||
type curve25519KEX struct {
|
|
||||||
secret [32]byte
|
|
||||||
public [32]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ KeyExchange = &curve25519KEX{}
|
|
||||||
|
|
||||||
// NewCurve25519KEX creates a new KeyExchange using Curve25519, see https://cr.yp.to/ecdh.html
|
|
||||||
func NewCurve25519KEX() (KeyExchange, error) {
|
|
||||||
c := &curve25519KEX{}
|
|
||||||
if _, err := rand.Read(c.secret[:]); err != nil {
|
|
||||||
return nil, errors.New("Curve25519: could not create private key")
|
|
||||||
}
|
|
||||||
curve25519.ScalarBaseMult(&c.public, &c.secret)
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *curve25519KEX) PublicKey() []byte {
|
|
||||||
return c.public[:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *curve25519KEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
|
|
||||||
if len(otherPublic) != 32 {
|
|
||||||
return nil, errors.New("Curve25519: expected public key of 32 byte")
|
|
||||||
}
|
|
||||||
var res [32]byte
|
|
||||||
var otherPublicArray [32]byte
|
|
||||||
copy(otherPublicArray[:], otherPublic)
|
|
||||||
curve25519.ScalarMult(&res, &c.secret, &otherPublicArray)
|
|
||||||
return res[:], nil
|
|
||||||
}
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/hmac"
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
// copied from https://github.com/cloudflare/tls-tris/blob/master/hkdf.go
|
||||||
|
func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
|
||||||
|
if salt == nil {
|
||||||
|
salt = make([]byte, hash.Size())
|
||||||
|
}
|
||||||
|
if secret == nil {
|
||||||
|
secret = make([]byte, hash.Size())
|
||||||
|
}
|
||||||
|
extractor := hmac.New(hash.New, salt)
|
||||||
|
extractor.Write(secret)
|
||||||
|
return extractor.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// copied from https://github.com/cloudflare/tls-tris/blob/master/hkdf.go
|
||||||
|
func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte {
|
||||||
|
var (
|
||||||
|
expander = hmac.New(hash.New, prk)
|
||||||
|
res = make([]byte, l)
|
||||||
|
counter = byte(1)
|
||||||
|
prev []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
if l > 255*expander.Size() {
|
||||||
|
panic("hkdf: requested too much output")
|
||||||
|
}
|
||||||
|
|
||||||
|
p := res
|
||||||
|
for len(p) > 0 {
|
||||||
|
expander.Reset()
|
||||||
|
expander.Write(prev)
|
||||||
|
expander.Write(info)
|
||||||
|
expander.Write([]byte{counter})
|
||||||
|
prev = expander.Sum(prev[:0])
|
||||||
|
counter++
|
||||||
|
n := copy(p, prev)
|
||||||
|
p = p[n:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// hkdfExpandLabel HKDF expands a label
|
||||||
|
func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, length int) []byte {
|
||||||
|
const prefix = "quic "
|
||||||
|
qlabel := make([]byte, 2 /* length */ +1 /* length of label */ +len(prefix)+len(label)+1 /* length of context (empty) */)
|
||||||
|
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
|
||||||
|
qlabel[2] = uint8(len(prefix) + len(label))
|
||||||
|
copy(qlabel[3:], []byte(prefix+label))
|
||||||
|
return hkdfExpand(hash, secret, qlabel, length)
|
||||||
|
}
|
|
@ -1,60 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"encoding/binary"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
clientExporterLabel = "EXPORTER-QUIC client 1rtt"
|
|
||||||
serverExporterLabel = "EXPORTER-QUIC server 1rtt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A TLSExporter gets the negotiated ciphersuite and computes exporter
|
|
||||||
type TLSExporter interface {
|
|
||||||
ConnectionState() mint.ConnectionState
|
|
||||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func qhkdfExpand(secret []byte, label string, length int) []byte {
|
|
||||||
qlabel := make([]byte, 2+1+5+len(label))
|
|
||||||
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
|
|
||||||
qlabel[2] = uint8(5 + len(label))
|
|
||||||
copy(qlabel[3:], []byte("QUIC "+label))
|
|
||||||
return mint.HkdfExpand(crypto.SHA256, secret, qlabel, length)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
|
|
||||||
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
|
|
||||||
var myLabel, otherLabel string
|
|
||||||
if pers == protocol.PerspectiveClient {
|
|
||||||
myLabel = clientExporterLabel
|
|
||||||
otherLabel = serverExporterLabel
|
|
||||||
} else {
|
|
||||||
myLabel = serverExporterLabel
|
|
||||||
otherLabel = clientExporterLabel
|
|
||||||
}
|
|
||||||
myKey, myIV, err := computeKeyAndIV(tls, myLabel)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
|
|
||||||
}
|
|
||||||
|
|
||||||
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
|
|
||||||
cs := tls.ConnectionState().CipherSuite
|
|
||||||
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
key = qhkdfExpand(secret, "key", cs.KeyLen)
|
|
||||||
iv = qhkdfExpand(secret, "iv", cs.IvLen)
|
|
||||||
return key, iv, nil
|
|
||||||
}
|
|
100
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go
generated
vendored
100
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation_quic_crypto.go
generated
vendored
|
@ -1,100 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/sha256"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/hkdf"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DeriveKeysChacha20 derives the client and server keys and creates a matching chacha20poly1305 AEAD instance
|
|
||||||
// func DeriveKeysChacha20(version protocol.VersionNumber, forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) {
|
|
||||||
// otherKey, myKey, otherIV, myIV, err := deriveKeys(version, forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 32)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
|
|
||||||
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
|
|
||||||
var swap bool
|
|
||||||
if pers == protocol.PerspectiveClient {
|
|
||||||
swap = true
|
|
||||||
}
|
|
||||||
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
|
|
||||||
}
|
|
||||||
|
|
||||||
// deriveKeys derives the keys and the IVs
|
|
||||||
// swap should be set true if generating the values for the client, and false for the server
|
|
||||||
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) {
|
|
||||||
var info bytes.Buffer
|
|
||||||
if forwardSecure {
|
|
||||||
info.Write([]byte("QUIC forward secure key expansion\x00"))
|
|
||||||
} else {
|
|
||||||
info.Write([]byte("QUIC key expansion\x00"))
|
|
||||||
}
|
|
||||||
info.Write(connID)
|
|
||||||
info.Write(chlo)
|
|
||||||
info.Write(scfg)
|
|
||||||
info.Write(cert)
|
|
||||||
|
|
||||||
r := hkdf.New(sha256.New, sharedSecret, nonces, info.Bytes())
|
|
||||||
|
|
||||||
s := make([]byte, 2*keyLen+2*4)
|
|
||||||
if _, err := io.ReadFull(r, s); err != nil {
|
|
||||||
return nil, nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
key1 := s[:keyLen]
|
|
||||||
key2 := s[keyLen : 2*keyLen]
|
|
||||||
iv1 := s[2*keyLen : 2*keyLen+4]
|
|
||||||
iv2 := s[2*keyLen+4:]
|
|
||||||
|
|
||||||
var otherKey, myKey []byte
|
|
||||||
var otherIV, myIV []byte
|
|
||||||
|
|
||||||
if !forwardSecure {
|
|
||||||
if err := diversify(key2, iv2, divNonce); err != nil {
|
|
||||||
return nil, nil, nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if swap {
|
|
||||||
otherKey = key2
|
|
||||||
myKey = key1
|
|
||||||
otherIV = iv2
|
|
||||||
myIV = iv1
|
|
||||||
} else {
|
|
||||||
otherKey = key1
|
|
||||||
myKey = key2
|
|
||||||
otherIV = iv1
|
|
||||||
myIV = iv2
|
|
||||||
}
|
|
||||||
|
|
||||||
return otherKey, myKey, otherIV, myIV, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func diversify(key, iv, divNonce []byte) error {
|
|
||||||
secret := make([]byte, len(key)+len(iv))
|
|
||||||
copy(secret, key)
|
|
||||||
copy(secret[len(key):], iv)
|
|
||||||
|
|
||||||
r := hkdf.New(sha256.New, secret, divNonce, []byte("QUIC key diversification"))
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, key); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := io.ReadFull(r, iv); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,7 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
// KeyExchange manages the exchange of keys
|
|
||||||
type KeyExchange interface {
|
|
||||||
PublicKey() []byte
|
|
||||||
CalculateSharedKey(otherPublic []byte) ([]byte, error)
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
// NewNullAEAD creates a NullAEAD
|
|
||||||
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
|
|
||||||
if v.UsesTLS() {
|
|
||||||
return newNullAEADAESGCM(connID, p)
|
|
||||||
}
|
|
||||||
return &nullAEADFNV128a{perspective: p}, nil
|
|
||||||
}
|
|
|
@ -3,13 +3,13 @@ package crypto
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
|
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
|
||||||
|
|
||||||
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
// NewNullAEAD creates a NullAEAD
|
||||||
|
func NewNullAEAD(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||||
clientSecret, serverSecret := computeSecrets(connectionID)
|
clientSecret, serverSecret := computeSecrets(connectionID)
|
||||||
|
|
||||||
var mySecret, otherSecret []byte
|
var mySecret, otherSecret []byte
|
||||||
|
@ -28,14 +28,14 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
|
||||||
handshakeSecret := mint.HkdfExtract(crypto.SHA256, quicVersion1Salt, connID)
|
initialSecret := hkdfExtract(crypto.SHA256, connID, quicVersion1Salt)
|
||||||
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
|
clientSecret = HkdfExpandLabel(crypto.SHA256, initialSecret, "client in", crypto.SHA256.Size())
|
||||||
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
|
serverSecret = HkdfExpandLabel(crypto.SHA256, initialSecret, "server in", crypto.SHA256.Size())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
|
||||||
key = qhkdfExpand(secret, "key", 16)
|
key = HkdfExpandLabel(crypto.SHA256, secret, "key", 16)
|
||||||
iv = qhkdfExpand(secret, "iv", 12)
|
iv = HkdfExpandLabel(crypto.SHA256, secret, "iv", 12)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,79 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash/fnv"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// nullAEAD handles not-yet encrypted packets
|
|
||||||
type nullAEADFNV128a struct {
|
|
||||||
perspective protocol.Perspective
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ AEAD = &nullAEADFNV128a{}
|
|
||||||
|
|
||||||
// Open and verify the ciphertext
|
|
||||||
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
|
||||||
if len(src) < 12 {
|
|
||||||
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := fnv.New128a()
|
|
||||||
hash.Write(associatedData)
|
|
||||||
hash.Write(src[12:])
|
|
||||||
if n.perspective == protocol.PerspectiveServer {
|
|
||||||
hash.Write([]byte("Client"))
|
|
||||||
} else {
|
|
||||||
hash.Write([]byte("Server"))
|
|
||||||
}
|
|
||||||
sum := make([]byte, 0, 16)
|
|
||||||
sum = hash.Sum(sum)
|
|
||||||
// The tag is written in little endian, so we need to reverse the slice.
|
|
||||||
reverse(sum)
|
|
||||||
|
|
||||||
if !bytes.Equal(sum[:12], src[:12]) {
|
|
||||||
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
|
|
||||||
}
|
|
||||||
return src[12:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Seal writes hash and ciphertext to the buffer
|
|
||||||
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
|
||||||
if cap(dst) < 12+len(src) {
|
|
||||||
dst = make([]byte, 12+len(src))
|
|
||||||
} else {
|
|
||||||
dst = dst[:12+len(src)]
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := fnv.New128a()
|
|
||||||
hash.Write(associatedData)
|
|
||||||
hash.Write(src)
|
|
||||||
|
|
||||||
if n.perspective == protocol.PerspectiveServer {
|
|
||||||
hash.Write([]byte("Server"))
|
|
||||||
} else {
|
|
||||||
hash.Write([]byte("Client"))
|
|
||||||
}
|
|
||||||
sum := make([]byte, 0, 16)
|
|
||||||
sum = hash.Sum(sum)
|
|
||||||
// The tag is written in little endian, so we need to reverse the slice.
|
|
||||||
reverse(sum)
|
|
||||||
|
|
||||||
copy(dst[12:], src)
|
|
||||||
copy(dst, sum[:12])
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *nullAEADFNV128a) Overhead() int {
|
|
||||||
return 12
|
|
||||||
}
|
|
||||||
|
|
||||||
func reverse(a []byte) {
|
|
||||||
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
|
|
||||||
a[left], a[right] = a[right], a[left]
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,66 +0,0 @@
|
||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/asn1"
|
|
||||||
"errors"
|
|
||||||
"math/big"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ecdsaSignature struct {
|
|
||||||
R, S *big.Int
|
|
||||||
}
|
|
||||||
|
|
||||||
// signServerProof signs CHLO and server config for use in the server proof
|
|
||||||
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) {
|
|
||||||
hash := sha256.New()
|
|
||||||
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
|
|
||||||
chloHash := sha256.Sum256(chlo)
|
|
||||||
hash.Write([]byte{32, 0, 0, 0})
|
|
||||||
hash.Write(chloHash[:])
|
|
||||||
hash.Write(serverConfigData)
|
|
||||||
|
|
||||||
key, ok := cert.PrivateKey.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := crypto.SignerOpts(crypto.SHA256)
|
|
||||||
|
|
||||||
if _, ok = key.(*rsa.PrivateKey); ok {
|
|
||||||
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
|
|
||||||
}
|
|
||||||
|
|
||||||
return key.Sign(rand.Reader, hash.Sum(nil), opts)
|
|
||||||
}
|
|
||||||
|
|
||||||
// verifyServerProof verifies the server proof signature
|
|
||||||
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool {
|
|
||||||
hash := sha256.New()
|
|
||||||
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
|
|
||||||
chloHash := sha256.Sum256(chlo)
|
|
||||||
hash.Write([]byte{32, 0, 0, 0})
|
|
||||||
hash.Write(chloHash[:])
|
|
||||||
hash.Write(serverConfigData)
|
|
||||||
|
|
||||||
// RSA
|
|
||||||
if cert.PublicKeyAlgorithm == x509.RSA {
|
|
||||||
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
|
|
||||||
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts)
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ECDSA
|
|
||||||
signature := &ecdsaSignature{}
|
|
||||||
rest, err := asn1.Unmarshal(proof, signature)
|
|
||||||
if err != nil || len(rest) != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S)
|
|
||||||
}
|
|
|
@ -5,8 +5,8 @@ import (
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type connectionFlowController struct {
|
type connectionFlowController struct {
|
||||||
|
|
|
@ -19,7 +19,7 @@ type StreamFlowController interface {
|
||||||
flowController
|
flowController
|
||||||
// for receiving
|
// for receiving
|
||||||
// UpdateHighestReceived should be called when a new highest offset is received
|
// UpdateHighestReceived should be called when a new highest offset is received
|
||||||
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
|
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
||||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
22
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
22
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
|
@ -5,8 +5,8 @@ import (
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type streamFlowController struct {
|
type streamFlowController struct {
|
||||||
|
@ -17,7 +17,6 @@ type streamFlowController struct {
|
||||||
queueWindowUpdate func()
|
queueWindowUpdate func()
|
||||||
|
|
||||||
connection connectionFlowControllerI
|
connection connectionFlowControllerI
|
||||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
|
||||||
|
|
||||||
receivedFinalOffset bool
|
receivedFinalOffset bool
|
||||||
}
|
}
|
||||||
|
@ -27,7 +26,6 @@ var _ StreamFlowController = &streamFlowController{}
|
||||||
// NewStreamFlowController gets a new flow controller for a stream
|
// NewStreamFlowController gets a new flow controller for a stream
|
||||||
func NewStreamFlowController(
|
func NewStreamFlowController(
|
||||||
streamID protocol.StreamID,
|
streamID protocol.StreamID,
|
||||||
contributesToConnection bool,
|
|
||||||
cfc ConnectionFlowController,
|
cfc ConnectionFlowController,
|
||||||
receiveWindow protocol.ByteCount,
|
receiveWindow protocol.ByteCount,
|
||||||
maxReceiveWindow protocol.ByteCount,
|
maxReceiveWindow protocol.ByteCount,
|
||||||
|
@ -38,7 +36,6 @@ func NewStreamFlowController(
|
||||||
) StreamFlowController {
|
) StreamFlowController {
|
||||||
return &streamFlowController{
|
return &streamFlowController{
|
||||||
streamID: streamID,
|
streamID: streamID,
|
||||||
contributesToConnection: contributesToConnection,
|
|
||||||
connection: cfc.(connectionFlowControllerI),
|
connection: cfc.(connectionFlowControllerI),
|
||||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||||
baseFlowController: baseFlowController{
|
baseFlowController: baseFlowController{
|
||||||
|
@ -87,32 +84,21 @@ func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCou
|
||||||
if c.checkFlowControlViolation() {
|
if c.checkFlowControlViolation() {
|
||||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
|
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
|
||||||
}
|
}
|
||||||
if c.contributesToConnection {
|
|
||||||
return c.connection.IncrementHighestReceived(increment)
|
return c.connection.IncrementHighestReceived(increment)
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||||
c.baseFlowController.AddBytesRead(n)
|
c.baseFlowController.AddBytesRead(n)
|
||||||
if c.contributesToConnection {
|
|
||||||
c.connection.AddBytesRead(n)
|
c.connection.AddBytesRead(n)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||||
c.baseFlowController.AddBytesSent(n)
|
c.baseFlowController.AddBytesSent(n)
|
||||||
if c.contributesToConnection {
|
|
||||||
c.connection.AddBytesSent(n)
|
c.connection.AddBytesSent(n)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||||
window := c.baseFlowController.sendWindowSize()
|
return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
|
||||||
if c.contributesToConnection {
|
|
||||||
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
|
||||||
}
|
|
||||||
return window
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
||||||
|
@ -122,9 +108,7 @@ func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
||||||
if hasWindowUpdate {
|
if hasWindowUpdate {
|
||||||
c.queueWindowUpdate()
|
c.queueWindowUpdate()
|
||||||
}
|
}
|
||||||
if c.contributesToConnection {
|
|
||||||
c.connection.MaybeQueueWindowUpdate()
|
c.connection.MaybeQueueWindowUpdate()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||||
|
@ -140,10 +124,8 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||||
offset := c.baseFlowController.getWindowUpdate()
|
offset := c.baseFlowController.getWindowUpdate()
|
||||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||||
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
||||||
if c.contributesToConnection {
|
|
||||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
c.mutex.Unlock()
|
c.mutex.Unlock()
|
||||||
return offset
|
return offset
|
||||||
}
|
}
|
||||||
|
|
58
vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
58
vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/binary"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sealer struct {
|
||||||
|
iv []byte
|
||||||
|
aead cipher.AEAD
|
||||||
|
|
||||||
|
// use a single slice to avoid allocations
|
||||||
|
nonceBuf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Sealer = &sealer{}
|
||||||
|
|
||||||
|
func newSealer(aead cipher.AEAD, iv []byte) Sealer {
|
||||||
|
return &sealer{
|
||||||
|
iv: iv,
|
||||||
|
aead: aead,
|
||||||
|
nonceBuf: make([]byte, aead.NonceSize()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||||
|
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
|
||||||
|
return s.aead.Seal(dst, s.nonceBuf, src, ad)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sealer) Overhead() int {
|
||||||
|
return s.aead.Overhead()
|
||||||
|
}
|
||||||
|
|
||||||
|
type opener struct {
|
||||||
|
iv []byte
|
||||||
|
aead cipher.AEAD
|
||||||
|
|
||||||
|
// use a single slice to avoid allocations
|
||||||
|
nonceBuf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Opener = &opener{}
|
||||||
|
|
||||||
|
func newOpener(aead cipher.AEAD, iv []byte) Opener {
|
||||||
|
return &opener{
|
||||||
|
iv: iv,
|
||||||
|
aead: aead,
|
||||||
|
nonceBuf: make([]byte, aead.NonceSize()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||||
|
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
|
||||||
|
return o.aead.Open(dst, o.nonceBuf, src, ad)
|
||||||
|
}
|
24
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
generated
vendored
|
@ -5,6 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -15,13 +17,16 @@ const (
|
||||||
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
|
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
|
||||||
type Cookie struct {
|
type Cookie struct {
|
||||||
RemoteAddr string
|
RemoteAddr string
|
||||||
// The time that the STK was issued (resolution 1 second)
|
OriginalDestConnectionID protocol.ConnectionID
|
||||||
|
// The time that the Cookie was issued (resolution 1 second)
|
||||||
SentTime time.Time
|
SentTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// token is the struct that is used for ASN1 serialization and deserialization
|
// token is the struct that is used for ASN1 serialization and deserialization
|
||||||
type token struct {
|
type token struct {
|
||||||
Data []byte
|
RemoteAddr []byte
|
||||||
|
OriginalDestConnectionID []byte
|
||||||
|
|
||||||
Timestamp int64
|
Timestamp int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,9 +47,10 @@ func NewCookieGenerator() (*CookieGenerator, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewToken generates a new Cookie for a given source address
|
// NewToken generates a new Cookie for a given source address
|
||||||
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
func (g *CookieGenerator) NewToken(raddr net.Addr, origConnID protocol.ConnectionID) ([]byte, error) {
|
||||||
data, err := asn1.Marshal(token{
|
data, err := asn1.Marshal(token{
|
||||||
Data: encodeRemoteAddr(raddr),
|
RemoteAddr: encodeRemoteAddr(raddr),
|
||||||
|
OriginalDestConnectionID: origConnID,
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -72,10 +78,14 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
|
||||||
if len(rest) != 0 {
|
if len(rest) != 0 {
|
||||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||||
}
|
}
|
||||||
return &Cookie{
|
cookie := &Cookie{
|
||||||
RemoteAddr: decodeRemoteAddr(t.Data),
|
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
|
||||||
SentTime: time.Unix(t.Timestamp, 0),
|
SentTime: time.Unix(t.Timestamp, 0),
|
||||||
}, nil
|
}
|
||||||
|
if len(t.OriginalDestConnectionID) > 0 {
|
||||||
|
cookie.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
|
||||||
|
}
|
||||||
|
return cookie, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie
|
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie
|
||||||
|
|
515
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
515
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
|
@ -0,0 +1,515 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/marten-seemann/qtls"
|
||||||
|
)
|
||||||
|
|
||||||
|
type messageType uint8
|
||||||
|
|
||||||
|
// TLS handshake message types.
|
||||||
|
const (
|
||||||
|
typeClientHello messageType = 1
|
||||||
|
typeServerHello messageType = 2
|
||||||
|
typeEncryptedExtensions messageType = 8
|
||||||
|
typeCertificate messageType = 11
|
||||||
|
typeCertificateRequest messageType = 13
|
||||||
|
typeCertificateVerify messageType = 15
|
||||||
|
typeFinished messageType = 20
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m messageType) String() string {
|
||||||
|
switch m {
|
||||||
|
case typeClientHello:
|
||||||
|
return "ClientHello"
|
||||||
|
case typeServerHello:
|
||||||
|
return "ServerHello"
|
||||||
|
case typeEncryptedExtensions:
|
||||||
|
return "EncryptedExtensions"
|
||||||
|
case typeCertificate:
|
||||||
|
return "Certificate"
|
||||||
|
case typeCertificateRequest:
|
||||||
|
return "CertificateRequest"
|
||||||
|
case typeCertificateVerify:
|
||||||
|
return "CertificateVerify"
|
||||||
|
case typeFinished:
|
||||||
|
return "Finished"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown message type: %d", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type cryptoSetup struct {
|
||||||
|
tlsConf *qtls.Config
|
||||||
|
|
||||||
|
messageChan chan []byte
|
||||||
|
|
||||||
|
readEncLevel protocol.EncryptionLevel
|
||||||
|
writeEncLevel protocol.EncryptionLevel
|
||||||
|
|
||||||
|
handleParamsCallback func(*TransportParameters)
|
||||||
|
|
||||||
|
// There are two ways that an error can occur during the handshake:
|
||||||
|
// 1. as a return value from qtls.Handshake()
|
||||||
|
// 2. when new data is passed to the crypto setup via HandleData()
|
||||||
|
// handshakeErrChan is closed when qtls.Handshake() errors
|
||||||
|
handshakeErrChan chan struct{}
|
||||||
|
// HandleData() sends errors on the messageErrChan
|
||||||
|
messageErrChan chan error
|
||||||
|
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
|
||||||
|
handshakeDone chan struct{}
|
||||||
|
// transport parameters are sent on the receivedTransportParams, as soon as they are received
|
||||||
|
receivedTransportParams <-chan TransportParameters
|
||||||
|
// is closed when Close() is called
|
||||||
|
closeChan chan struct{}
|
||||||
|
|
||||||
|
clientHelloWritten bool
|
||||||
|
clientHelloWrittenChan chan struct{}
|
||||||
|
|
||||||
|
initialStream io.Writer
|
||||||
|
initialAEAD crypto.AEAD
|
||||||
|
|
||||||
|
handshakeStream io.Writer
|
||||||
|
handshakeOpener Opener
|
||||||
|
handshakeSealer Sealer
|
||||||
|
|
||||||
|
opener Opener
|
||||||
|
sealer Sealer
|
||||||
|
// TODO: add a 1-RTT stream (used for session tickets)
|
||||||
|
|
||||||
|
receivedWriteKey chan struct{}
|
||||||
|
receivedReadKey chan struct{}
|
||||||
|
|
||||||
|
logger utils.Logger
|
||||||
|
|
||||||
|
perspective protocol.Perspective
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ qtls.RecordLayer = &cryptoSetup{}
|
||||||
|
var _ CryptoSetup = &cryptoSetup{}
|
||||||
|
|
||||||
|
// NewCryptoSetupClient creates a new crypto setup for the client
|
||||||
|
func NewCryptoSetupClient(
|
||||||
|
initialStream io.Writer,
|
||||||
|
handshakeStream io.Writer,
|
||||||
|
origConnID protocol.ConnectionID,
|
||||||
|
connID protocol.ConnectionID,
|
||||||
|
params *TransportParameters,
|
||||||
|
handleParams func(*TransportParameters),
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
initialVersion protocol.VersionNumber,
|
||||||
|
supportedVersions []protocol.VersionNumber,
|
||||||
|
currentVersion protocol.VersionNumber,
|
||||||
|
logger utils.Logger,
|
||||||
|
perspective protocol.Perspective,
|
||||||
|
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||||
|
extHandler, receivedTransportParams := newExtensionHandlerClient(
|
||||||
|
params,
|
||||||
|
origConnID,
|
||||||
|
initialVersion,
|
||||||
|
supportedVersions,
|
||||||
|
currentVersion,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
return newCryptoSetup(
|
||||||
|
initialStream,
|
||||||
|
handshakeStream,
|
||||||
|
connID,
|
||||||
|
extHandler,
|
||||||
|
receivedTransportParams,
|
||||||
|
handleParams,
|
||||||
|
tlsConf,
|
||||||
|
logger,
|
||||||
|
perspective,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCryptoSetupServer creates a new crypto setup for the server
|
||||||
|
func NewCryptoSetupServer(
|
||||||
|
initialStream io.Writer,
|
||||||
|
handshakeStream io.Writer,
|
||||||
|
connID protocol.ConnectionID,
|
||||||
|
params *TransportParameters,
|
||||||
|
handleParams func(*TransportParameters),
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
supportedVersions []protocol.VersionNumber,
|
||||||
|
currentVersion protocol.VersionNumber,
|
||||||
|
logger utils.Logger,
|
||||||
|
perspective protocol.Perspective,
|
||||||
|
) (CryptoSetup, error) {
|
||||||
|
extHandler, receivedTransportParams := newExtensionHandlerServer(
|
||||||
|
params,
|
||||||
|
supportedVersions,
|
||||||
|
currentVersion,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
cs, _, err := newCryptoSetup(
|
||||||
|
initialStream,
|
||||||
|
handshakeStream,
|
||||||
|
connID,
|
||||||
|
extHandler,
|
||||||
|
receivedTransportParams,
|
||||||
|
handleParams,
|
||||||
|
tlsConf,
|
||||||
|
logger,
|
||||||
|
perspective,
|
||||||
|
)
|
||||||
|
return cs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCryptoSetup(
|
||||||
|
initialStream io.Writer,
|
||||||
|
handshakeStream io.Writer,
|
||||||
|
connID protocol.ConnectionID,
|
||||||
|
extHandler tlsExtensionHandler,
|
||||||
|
transportParamChan <-chan TransportParameters,
|
||||||
|
handleParams func(*TransportParameters),
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
logger utils.Logger,
|
||||||
|
perspective protocol.Perspective,
|
||||||
|
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||||
|
initialAEAD, err := crypto.NewNullAEAD(connID, perspective)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
cs := &cryptoSetup{
|
||||||
|
initialStream: initialStream,
|
||||||
|
initialAEAD: initialAEAD,
|
||||||
|
handshakeStream: handshakeStream,
|
||||||
|
readEncLevel: protocol.EncryptionInitial,
|
||||||
|
writeEncLevel: protocol.EncryptionInitial,
|
||||||
|
handleParamsCallback: handleParams,
|
||||||
|
receivedTransportParams: transportParamChan,
|
||||||
|
logger: logger,
|
||||||
|
perspective: perspective,
|
||||||
|
handshakeDone: make(chan struct{}),
|
||||||
|
handshakeErrChan: make(chan struct{}),
|
||||||
|
messageErrChan: make(chan error, 1),
|
||||||
|
clientHelloWrittenChan: make(chan struct{}),
|
||||||
|
messageChan: make(chan []byte, 100),
|
||||||
|
receivedReadKey: make(chan struct{}),
|
||||||
|
receivedWriteKey: make(chan struct{}),
|
||||||
|
closeChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
qtlsConf := tlsConfigToQtlsConfig(tlsConf)
|
||||||
|
qtlsConf.AlternativeRecordLayer = cs
|
||||||
|
qtlsConf.GetExtensions = extHandler.GetExtensions
|
||||||
|
qtlsConf.ReceivedExtensions = extHandler.ReceivedExtensions
|
||||||
|
cs.tlsConf = qtlsConf
|
||||||
|
return cs, cs.clientHelloWrittenChan, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) RunHandshake() error {
|
||||||
|
var conn *qtls.Conn
|
||||||
|
switch h.perspective {
|
||||||
|
case protocol.PerspectiveClient:
|
||||||
|
conn = qtls.Client(nil, h.tlsConf)
|
||||||
|
case protocol.PerspectiveServer:
|
||||||
|
conn = qtls.Server(nil, h.tlsConf)
|
||||||
|
}
|
||||||
|
// Handle errors that might occur when HandleData() is called.
|
||||||
|
handshakeErrChan := make(chan error, 1)
|
||||||
|
handshakeComplete := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(h.handshakeDone)
|
||||||
|
if err := conn.Handshake(); err != nil {
|
||||||
|
handshakeErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
close(handshakeComplete)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-h.closeChan:
|
||||||
|
close(h.messageChan)
|
||||||
|
// wait until the Handshake() go routine has returned
|
||||||
|
<-handshakeErrChan
|
||||||
|
return errors.New("Handshake aborted")
|
||||||
|
case <-handshakeComplete: // return when the handshake is done
|
||||||
|
return nil
|
||||||
|
case err := <-handshakeErrChan:
|
||||||
|
// if handleMessageFor{server,client} are waiting for some qtls action, make them return
|
||||||
|
close(h.handshakeErrChan)
|
||||||
|
return err
|
||||||
|
case err := <-h.messageErrChan:
|
||||||
|
// If the handshake errored because of an error that occurred during HandleData(),
|
||||||
|
// that error message will be more useful than the error message generated by Handshake().
|
||||||
|
// Close the message chan that qtls is receiving messages from.
|
||||||
|
// This will make qtls.Handshake() return.
|
||||||
|
// Thereby the go routine running qtls.Handshake() will return.
|
||||||
|
close(h.messageChan)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) Close() error {
|
||||||
|
close(h.closeChan)
|
||||||
|
// wait until qtls.Handshake() actually returned
|
||||||
|
<-h.handshakeDone
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMessage handles a TLS handshake message.
|
||||||
|
// It is called by the crypto streams when a new message is available.
|
||||||
|
// It returns if it is done with messages on the same encryption level.
|
||||||
|
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
|
||||||
|
msgType := messageType(data[0])
|
||||||
|
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
|
||||||
|
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
|
||||||
|
h.messageErrChan <- err
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h.messageChan <- data
|
||||||
|
switch h.perspective {
|
||||||
|
case protocol.PerspectiveClient:
|
||||||
|
return h.handleMessageForClient(msgType)
|
||||||
|
case protocol.PerspectiveServer:
|
||||||
|
return h.handleMessageForServer(msgType)
|
||||||
|
default:
|
||||||
|
panic("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||||
|
var expected protocol.EncryptionLevel
|
||||||
|
switch msgType {
|
||||||
|
case typeClientHello,
|
||||||
|
typeServerHello:
|
||||||
|
expected = protocol.EncryptionInitial
|
||||||
|
case typeEncryptedExtensions,
|
||||||
|
typeCertificate,
|
||||||
|
typeCertificateRequest,
|
||||||
|
typeCertificateVerify,
|
||||||
|
typeFinished:
|
||||||
|
expected = protocol.EncryptionHandshake
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unexpected handshake message: %d", msgType)
|
||||||
|
}
|
||||||
|
if encLevel != expected {
|
||||||
|
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
|
||||||
|
switch msgType {
|
||||||
|
case typeClientHello:
|
||||||
|
select {
|
||||||
|
case params := <-h.receivedTransportParams:
|
||||||
|
h.handleParamsCallback(¶ms)
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// get the handshake write key
|
||||||
|
select {
|
||||||
|
case <-h.receivedWriteKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// get the 1-RTT write key
|
||||||
|
select {
|
||||||
|
case <-h.receivedWriteKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// get the handshake read key
|
||||||
|
// TODO: check that the initial stream doesn't have any more data
|
||||||
|
select {
|
||||||
|
case <-h.receivedReadKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case typeCertificate, typeCertificateVerify:
|
||||||
|
// nothing to do
|
||||||
|
return false
|
||||||
|
case typeFinished:
|
||||||
|
// get the 1-RTT read key
|
||||||
|
select {
|
||||||
|
case <-h.receivedReadKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
panic("unexpected handshake message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
|
||||||
|
switch msgType {
|
||||||
|
case typeServerHello:
|
||||||
|
// get the handshake read key
|
||||||
|
select {
|
||||||
|
case <-h.receivedReadKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case typeEncryptedExtensions:
|
||||||
|
select {
|
||||||
|
case params := <-h.receivedTransportParams:
|
||||||
|
h.handleParamsCallback(¶ms)
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
|
||||||
|
// nothing to do
|
||||||
|
return false
|
||||||
|
case typeFinished:
|
||||||
|
// get the handshake write key
|
||||||
|
select {
|
||||||
|
case <-h.receivedWriteKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// While the order of these two is not defined by the TLS spec,
|
||||||
|
// we have to do it on the same order as our TLS library does it.
|
||||||
|
// get the handshake write key
|
||||||
|
select {
|
||||||
|
case <-h.receivedWriteKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// get the 1-RTT read key
|
||||||
|
select {
|
||||||
|
case <-h.receivedReadKey:
|
||||||
|
case <-h.handshakeErrChan:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
panic("unexpected handshake message: ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadHandshakeMessage is called by TLS.
|
||||||
|
// It blocks until a new handshake message is available.
|
||||||
|
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||||
|
// TODO: add some error handling here (when the session is closed)
|
||||||
|
msg, ok := <-h.messageChan
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("error while handling the handshake message")
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||||
|
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
|
||||||
|
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
|
||||||
|
opener := newOpener(suite.AEAD(key, iv), iv)
|
||||||
|
|
||||||
|
switch h.readEncLevel {
|
||||||
|
case protocol.EncryptionInitial:
|
||||||
|
h.readEncLevel = protocol.EncryptionHandshake
|
||||||
|
h.handshakeOpener = opener
|
||||||
|
h.logger.Debugf("Installed Handshake Read keys")
|
||||||
|
case protocol.EncryptionHandshake:
|
||||||
|
h.readEncLevel = protocol.Encryption1RTT
|
||||||
|
h.opener = opener
|
||||||
|
h.logger.Debugf("Installed 1-RTT Read keys")
|
||||||
|
default:
|
||||||
|
panic("unexpected read encryption level")
|
||||||
|
}
|
||||||
|
h.receivedReadKey <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||||
|
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
|
||||||
|
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
|
||||||
|
sealer := newSealer(suite.AEAD(key, iv), iv)
|
||||||
|
|
||||||
|
switch h.writeEncLevel {
|
||||||
|
case protocol.EncryptionInitial:
|
||||||
|
h.writeEncLevel = protocol.EncryptionHandshake
|
||||||
|
h.handshakeSealer = sealer
|
||||||
|
h.logger.Debugf("Installed Handshake Write keys")
|
||||||
|
case protocol.EncryptionHandshake:
|
||||||
|
h.writeEncLevel = protocol.Encryption1RTT
|
||||||
|
h.sealer = sealer
|
||||||
|
h.logger.Debugf("Installed 1-RTT Write keys")
|
||||||
|
default:
|
||||||
|
panic("unexpected write encryption level")
|
||||||
|
}
|
||||||
|
h.receivedWriteKey <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteRecord is called when TLS writes data
|
||||||
|
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
||||||
|
switch h.writeEncLevel {
|
||||||
|
case protocol.EncryptionInitial:
|
||||||
|
// assume that the first WriteRecord call contains the ClientHello
|
||||||
|
n, err := h.initialStream.Write(p)
|
||||||
|
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
|
||||||
|
h.clientHelloWritten = true
|
||||||
|
close(h.clientHelloWrittenChan)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
case protocol.EncryptionHandshake:
|
||||||
|
return h.handshakeStream.Write(p)
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unexpected write encryption level: %s", h.writeEncLevel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||||
|
if h.sealer != nil {
|
||||||
|
return protocol.Encryption1RTT, h.sealer
|
||||||
|
}
|
||||||
|
if h.handshakeSealer != nil {
|
||||||
|
return protocol.EncryptionHandshake, h.handshakeSealer
|
||||||
|
}
|
||||||
|
return protocol.EncryptionInitial, h.initialAEAD
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
|
||||||
|
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
|
||||||
|
|
||||||
|
switch level {
|
||||||
|
case protocol.EncryptionInitial:
|
||||||
|
return h.initialAEAD, nil
|
||||||
|
case protocol.EncryptionHandshake:
|
||||||
|
if h.handshakeSealer == nil {
|
||||||
|
return nil, errNoSealer
|
||||||
|
}
|
||||||
|
return h.handshakeSealer, nil
|
||||||
|
case protocol.Encryption1RTT:
|
||||||
|
if h.sealer == nil {
|
||||||
|
return nil, errNoSealer
|
||||||
|
}
|
||||||
|
return h.sealer, nil
|
||||||
|
default:
|
||||||
|
return nil, errNoSealer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||||
|
return h.initialAEAD.Open(dst, src, pn, ad)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||||
|
if h.handshakeOpener == nil {
|
||||||
|
return nil, errors.New("no handshake opener")
|
||||||
|
}
|
||||||
|
return h.handshakeOpener.Open(dst, src, pn, ad)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||||
|
if h.opener == nil {
|
||||||
|
return nil, errors.New("no 1-RTT opener")
|
||||||
|
}
|
||||||
|
return h.opener.Open(dst, src, pn, ad)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||||
|
// TODO: return the connection state
|
||||||
|
return ConnectionState{}
|
||||||
|
}
|
543
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
543
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
generated
vendored
|
@ -1,543 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
type cryptoSetupClient struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
hostname string
|
|
||||||
connID protocol.ConnectionID
|
|
||||||
version protocol.VersionNumber
|
|
||||||
initialVersion protocol.VersionNumber
|
|
||||||
negotiatedVersions []protocol.VersionNumber
|
|
||||||
|
|
||||||
cryptoStream io.ReadWriter
|
|
||||||
|
|
||||||
serverConfig *serverConfigClient
|
|
||||||
|
|
||||||
stk []byte
|
|
||||||
sno []byte
|
|
||||||
nonc []byte
|
|
||||||
proof []byte
|
|
||||||
chloForSignature []byte
|
|
||||||
lastSentCHLO []byte
|
|
||||||
certManager crypto.CertManager
|
|
||||||
|
|
||||||
divNonceChan chan struct{}
|
|
||||||
diversificationNonce []byte
|
|
||||||
|
|
||||||
clientHelloCounter int
|
|
||||||
serverVerified bool // has the certificate chain and the proof already been verified
|
|
||||||
keyDerivation QuicCryptoKeyDerivationFunction
|
|
||||||
|
|
||||||
receivedSecurePacket bool
|
|
||||||
nullAEAD crypto.AEAD
|
|
||||||
secureAEAD crypto.AEAD
|
|
||||||
forwardSecureAEAD crypto.AEAD
|
|
||||||
|
|
||||||
paramsChan chan<- TransportParameters
|
|
||||||
handshakeEvent chan<- struct{}
|
|
||||||
|
|
||||||
params *TransportParameters
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ CryptoSetup = &cryptoSetupClient{}
|
|
||||||
|
|
||||||
var (
|
|
||||||
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available")
|
|
||||||
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated")
|
|
||||||
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces")
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
|
||||||
func NewCryptoSetupClient(
|
|
||||||
cryptoStream io.ReadWriter,
|
|
||||||
hostname string,
|
|
||||||
connID protocol.ConnectionID,
|
|
||||||
version protocol.VersionNumber,
|
|
||||||
tlsConfig *tls.Config,
|
|
||||||
params *TransportParameters,
|
|
||||||
paramsChan chan<- TransportParameters,
|
|
||||||
handshakeEvent chan<- struct{},
|
|
||||||
initialVersion protocol.VersionNumber,
|
|
||||||
negotiatedVersions []protocol.VersionNumber,
|
|
||||||
logger utils.Logger,
|
|
||||||
) (CryptoSetup, error) {
|
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
divNonceChan := make(chan struct{})
|
|
||||||
cs := &cryptoSetupClient{
|
|
||||||
cryptoStream: cryptoStream,
|
|
||||||
hostname: hostname,
|
|
||||||
connID: connID,
|
|
||||||
version: version,
|
|
||||||
certManager: crypto.NewCertManager(tlsConfig),
|
|
||||||
params: params,
|
|
||||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
|
||||||
nullAEAD: nullAEAD,
|
|
||||||
paramsChan: paramsChan,
|
|
||||||
handshakeEvent: handshakeEvent,
|
|
||||||
initialVersion: initialVersion,
|
|
||||||
// The server might have sent greased versions in the Version Negotiation packet.
|
|
||||||
// We need strip those from the list, since they won't be included in the handshake tag.
|
|
||||||
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
|
|
||||||
divNonceChan: divNonceChan,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
return cs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
|
||||||
messageChan := make(chan HandshakeMessage)
|
|
||||||
errorChan := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
message, err := ParseHandshakeMessage(h.cryptoStream)
|
|
||||||
if err != nil {
|
|
||||||
errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
messageChan <- message
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
if err := h.maybeUpgradeCrypto(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mutex.RLock()
|
|
||||||
sendCHLO := h.secureAEAD == nil
|
|
||||||
h.mutex.RUnlock()
|
|
||||||
if sendCHLO {
|
|
||||||
if err := h.sendCHLO(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var message HandshakeMessage
|
|
||||||
select {
|
|
||||||
case <-h.divNonceChan:
|
|
||||||
// there's no message to process, but we should try upgrading the crypto again
|
|
||||||
continue
|
|
||||||
case message = <-messageChan:
|
|
||||||
case err := <-errorChan:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Debugf("Got %s", message)
|
|
||||||
switch message.Tag {
|
|
||||||
case TagREJ:
|
|
||||||
if err := h.handleREJMessage(message.Data); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case TagSHLO:
|
|
||||||
params, err := h.handleSHLOMessage(message.Data)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// blocks until the session has received the parameters
|
|
||||||
h.paramsChan <- *params
|
|
||||||
h.handshakeEvent <- struct{}{}
|
|
||||||
close(h.handshakeEvent)
|
|
||||||
default:
|
|
||||||
return qerr.InvalidCryptoMessageType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if stk, ok := cryptoData[TagSTK]; ok {
|
|
||||||
h.stk = stk
|
|
||||||
}
|
|
||||||
|
|
||||||
if sno, ok := cryptoData[TagSNO]; ok {
|
|
||||||
h.sno = sno
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: what happens if the server sends a different server config in two packets?
|
|
||||||
if scfg, ok := cryptoData[TagSCFG]; ok {
|
|
||||||
h.serverConfig, err = parseServerConfig(scfg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.serverConfig.IsExpired() {
|
|
||||||
return qerr.CryptoServerConfigExpired
|
|
||||||
}
|
|
||||||
|
|
||||||
// now that we have a server config, we can use its OBIT value to generate a client nonce
|
|
||||||
if len(h.nonc) == 0 {
|
|
||||||
err = h.generateClientNonce()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if proof, ok := cryptoData[TagPROF]; ok {
|
|
||||||
h.proof = proof
|
|
||||||
h.chloForSignature = h.lastSentCHLO
|
|
||||||
}
|
|
||||||
|
|
||||||
if crt, ok := cryptoData[TagCERT]; ok {
|
|
||||||
err := h.certManager.SetData(crt)
|
|
||||||
if err != nil {
|
|
||||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = h.certManager.Verify(h.hostname)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Infof("Certificate validation failed: %s", err.Error())
|
|
||||||
return qerr.ProofInvalid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
|
|
||||||
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
|
|
||||||
if !validProof {
|
|
||||||
h.logger.Infof("Server proof verification failed")
|
|
||||||
return qerr.ProofInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
h.serverVerified = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
if !h.receivedSecurePacket {
|
|
||||||
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
|
||||||
}
|
|
||||||
|
|
||||||
if sno, ok := cryptoData[TagSNO]; ok {
|
|
||||||
h.sno = sno
|
|
||||||
}
|
|
||||||
|
|
||||||
serverPubs, ok := cryptoData[TagPUBS]
|
|
||||||
if !ok {
|
|
||||||
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
|
||||||
}
|
|
||||||
|
|
||||||
verTag, ok := cryptoData[TagVER]
|
|
||||||
if !ok {
|
|
||||||
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
|
||||||
}
|
|
||||||
if !h.validateVersionList(verTag) {
|
|
||||||
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
|
||||||
}
|
|
||||||
|
|
||||||
nonce := append(h.nonc, h.sno...)
|
|
||||||
|
|
||||||
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
leafCert := h.certManager.GetLeafCert()
|
|
||||||
|
|
||||||
h.forwardSecureAEAD, err = h.keyDerivation(
|
|
||||||
true,
|
|
||||||
ephermalSharedSecret,
|
|
||||||
nonce,
|
|
||||||
h.connID,
|
|
||||||
h.lastSentCHLO,
|
|
||||||
h.serverConfig.Get(),
|
|
||||||
leafCert,
|
|
||||||
nil,
|
|
||||||
protocol.PerspectiveClient,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
|
|
||||||
|
|
||||||
params, err := readHelloMap(cryptoData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, qerr.InvalidCryptoMessageParameter
|
|
||||||
}
|
|
||||||
return params, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
|
|
||||||
numNegotiatedVersions := len(h.negotiatedVersions)
|
|
||||||
if numNegotiatedVersions == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
b := bytes.NewReader(verTags)
|
|
||||||
for i := 0; i < numNegotiatedVersions; i++ {
|
|
||||||
v, err := utils.BigEndian.ReadUint32(b)
|
|
||||||
if err != nil { // should never occur, since the length was already checked
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
if h.forwardSecureAEAD != nil {
|
|
||||||
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
|
||||||
if err == nil {
|
|
||||||
return data, protocol.EncryptionForwardSecure, nil
|
|
||||||
}
|
|
||||||
return nil, protocol.EncryptionUnspecified, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.secureAEAD != nil {
|
|
||||||
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
|
||||||
if err == nil {
|
|
||||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
|
||||||
h.receivedSecurePacket = true
|
|
||||||
return data, protocol.EncryptionSecure, nil
|
|
||||||
}
|
|
||||||
if h.receivedSecurePacket {
|
|
||||||
return nil, protocol.EncryptionUnspecified, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, protocol.EncryptionUnspecified, err
|
|
||||||
}
|
|
||||||
return res, protocol.EncryptionUnencrypted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
if h.forwardSecureAEAD != nil {
|
|
||||||
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
|
|
||||||
} else if h.secureAEAD != nil {
|
|
||||||
return protocol.EncryptionSecure, h.secureAEAD
|
|
||||||
} else {
|
|
||||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
|
||||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
switch encLevel {
|
|
||||||
case protocol.EncryptionUnencrypted:
|
|
||||||
return h.nullAEAD, nil
|
|
||||||
case protocol.EncryptionSecure:
|
|
||||||
if h.secureAEAD == nil {
|
|
||||||
return nil, errors.New("CryptoSetupClient: no secureAEAD")
|
|
||||||
}
|
|
||||||
return h.secureAEAD, nil
|
|
||||||
case protocol.EncryptionForwardSecure:
|
|
||||||
if h.forwardSecureAEAD == nil {
|
|
||||||
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
|
|
||||||
}
|
|
||||||
return h.forwardSecureAEAD, nil
|
|
||||||
}
|
|
||||||
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
return ConnectionState{
|
|
||||||
HandshakeComplete: h.forwardSecureAEAD != nil,
|
|
||||||
PeerCertificates: h.certManager.GetChain(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
|
|
||||||
h.mutex.Lock()
|
|
||||||
if len(h.diversificationNonce) > 0 {
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
if !bytes.Equal(h.diversificationNonce, divNonce) {
|
|
||||||
return errConflictingDiversificationNonces
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
h.diversificationNonce = divNonce
|
|
||||||
h.mutex.Unlock()
|
|
||||||
h.divNonceChan <- struct{}{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) sendCHLO() error {
|
|
||||||
h.clientHelloCounter++
|
|
||||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
|
||||||
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos))
|
|
||||||
}
|
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
|
|
||||||
tags, err := h.getTags()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
h.addPadding(tags)
|
|
||||||
message := HandshakeMessage{
|
|
||||||
Tag: TagCHLO,
|
|
||||||
Data: tags,
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Debugf("Sending %s", message)
|
|
||||||
message.Write(b)
|
|
||||||
|
|
||||||
_, err = h.cryptoStream.Write(b.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.lastSentCHLO = b.Bytes()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
|
||||||
tags := h.params.getHelloMap()
|
|
||||||
tags[TagSNI] = []byte(h.hostname)
|
|
||||||
tags[TagPDMD] = []byte("X509")
|
|
||||||
|
|
||||||
ccs := h.certManager.GetCommonCertificateHashes()
|
|
||||||
if len(ccs) > 0 {
|
|
||||||
tags[TagCCS] = ccs
|
|
||||||
}
|
|
||||||
|
|
||||||
versionTag := make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
|
|
||||||
tags[TagVER] = versionTag
|
|
||||||
|
|
||||||
if len(h.stk) > 0 {
|
|
||||||
tags[TagSTK] = h.stk
|
|
||||||
}
|
|
||||||
if len(h.sno) > 0 {
|
|
||||||
tags[TagSNO] = h.sno
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.serverConfig != nil {
|
|
||||||
tags[TagSCID] = h.serverConfig.ID
|
|
||||||
|
|
||||||
leafCert := h.certManager.GetLeafCert()
|
|
||||||
if leafCert != nil {
|
|
||||||
certHash, _ := h.certManager.GetLeafCertHash()
|
|
||||||
xlct := make([]byte, 8)
|
|
||||||
binary.LittleEndian.PutUint64(xlct, certHash)
|
|
||||||
|
|
||||||
tags[TagNONC] = h.nonc
|
|
||||||
tags[TagXLCT] = xlct
|
|
||||||
tags[TagKEXS] = []byte("C255")
|
|
||||||
tags[TagAEAD] = []byte("AESG")
|
|
||||||
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
|
|
||||||
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
|
|
||||||
var size int
|
|
||||||
for _, tag := range tags {
|
|
||||||
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
|
||||||
}
|
|
||||||
paddingSize := protocol.MinClientHelloSize - size
|
|
||||||
if paddingSize > 0 {
|
|
||||||
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
|
||||||
if !h.serverVerified {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
leafCert := h.certManager.GetLeafCert()
|
|
||||||
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) {
|
|
||||||
var err error
|
|
||||||
var nonce []byte
|
|
||||||
if h.sno == nil {
|
|
||||||
nonce = h.nonc
|
|
||||||
} else {
|
|
||||||
nonce = append(h.nonc, h.sno...)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.secureAEAD, err = h.keyDerivation(
|
|
||||||
false,
|
|
||||||
h.serverConfig.sharedSecret,
|
|
||||||
nonce,
|
|
||||||
h.connID,
|
|
||||||
h.lastSentCHLO,
|
|
||||||
h.serverConfig.Get(),
|
|
||||||
leafCert,
|
|
||||||
h.diversificationNonce,
|
|
||||||
protocol.PerspectiveClient,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
|
||||||
h.handshakeEvent <- struct{}{}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) generateClientNonce() error {
|
|
||||||
if len(h.nonc) > 0 {
|
|
||||||
return errClientNonceAlreadyExists
|
|
||||||
}
|
|
||||||
|
|
||||||
nonc := make([]byte, 32)
|
|
||||||
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix()))
|
|
||||||
|
|
||||||
if len(h.serverConfig.obit) != 8 {
|
|
||||||
return errNoObitForClientNonce
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(nonc[4:12], h.serverConfig.obit)
|
|
||||||
|
|
||||||
_, err := rand.Read(nonc[12:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.nonc = nonc
|
|
||||||
return nil
|
|
||||||
}
|
|
467
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
467
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_server.go
generated
vendored
|
@ -1,467 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QuicCryptoKeyDerivationFunction is used for key derivation
|
|
||||||
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
|
|
||||||
|
|
||||||
// KeyExchangeFunction is used to make a new KEX
|
|
||||||
type KeyExchangeFunction func() (crypto.KeyExchange, error)
|
|
||||||
|
|
||||||
// The CryptoSetupServer handles all things crypto for the Session
|
|
||||||
type cryptoSetupServer struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
connID protocol.ConnectionID
|
|
||||||
remoteAddr net.Addr
|
|
||||||
scfg *ServerConfig
|
|
||||||
diversificationNonce []byte
|
|
||||||
|
|
||||||
version protocol.VersionNumber
|
|
||||||
supportedVersions []protocol.VersionNumber
|
|
||||||
|
|
||||||
acceptSTKCallback func(net.Addr, *Cookie) bool
|
|
||||||
|
|
||||||
nullAEAD crypto.AEAD
|
|
||||||
secureAEAD crypto.AEAD
|
|
||||||
forwardSecureAEAD crypto.AEAD
|
|
||||||
receivedForwardSecurePacket bool
|
|
||||||
receivedSecurePacket bool
|
|
||||||
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
|
|
||||||
|
|
||||||
receivedParams bool
|
|
||||||
paramsChan chan<- TransportParameters
|
|
||||||
handshakeEvent chan<- struct{}
|
|
||||||
|
|
||||||
keyDerivation QuicCryptoKeyDerivationFunction
|
|
||||||
keyExchange KeyExchangeFunction
|
|
||||||
|
|
||||||
cryptoStream io.ReadWriter
|
|
||||||
|
|
||||||
params *TransportParameters
|
|
||||||
|
|
||||||
sni string // need to fill out the ConnectionState
|
|
||||||
|
|
||||||
logger utils.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ CryptoSetup = &cryptoSetupServer{}
|
|
||||||
|
|
||||||
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
|
|
||||||
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
|
|
||||||
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
|
|
||||||
|
|
||||||
// NewCryptoSetup creates a new CryptoSetup instance for a server
|
|
||||||
func NewCryptoSetup(
|
|
||||||
cryptoStream io.ReadWriter,
|
|
||||||
connID protocol.ConnectionID,
|
|
||||||
remoteAddr net.Addr,
|
|
||||||
version protocol.VersionNumber,
|
|
||||||
divNonce []byte,
|
|
||||||
scfg *ServerConfig,
|
|
||||||
params *TransportParameters,
|
|
||||||
supportedVersions []protocol.VersionNumber,
|
|
||||||
acceptSTK func(net.Addr, *Cookie) bool,
|
|
||||||
paramsChan chan<- TransportParameters,
|
|
||||||
handshakeEvent chan<- struct{},
|
|
||||||
logger utils.Logger,
|
|
||||||
) (CryptoSetup, error) {
|
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &cryptoSetupServer{
|
|
||||||
cryptoStream: cryptoStream,
|
|
||||||
connID: connID,
|
|
||||||
remoteAddr: remoteAddr,
|
|
||||||
version: version,
|
|
||||||
supportedVersions: supportedVersions,
|
|
||||||
diversificationNonce: divNonce,
|
|
||||||
scfg: scfg,
|
|
||||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
|
||||||
keyExchange: getEphermalKEX,
|
|
||||||
nullAEAD: nullAEAD,
|
|
||||||
params: params,
|
|
||||||
acceptSTKCallback: acceptSTK,
|
|
||||||
sentSHLO: make(chan struct{}),
|
|
||||||
paramsChan: paramsChan,
|
|
||||||
handshakeEvent: handshakeEvent,
|
|
||||||
logger: logger,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleCryptoStream reads and writes messages on the crypto stream
|
|
||||||
func (h *cryptoSetupServer) HandleCryptoStream() error {
|
|
||||||
for {
|
|
||||||
var chloData bytes.Buffer
|
|
||||||
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
|
|
||||||
if err != nil {
|
|
||||||
return qerr.HandshakeFailed
|
|
||||||
}
|
|
||||||
if message.Tag != TagCHLO {
|
|
||||||
return qerr.InvalidCryptoMessageType
|
|
||||||
}
|
|
||||||
|
|
||||||
h.logger.Debugf("Got %s", message)
|
|
||||||
done, err := h.handleMessage(chloData.Bytes(), message.Data)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if done {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
|
|
||||||
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
|
|
||||||
return false, ErrNSTPExperiment
|
|
||||||
}
|
|
||||||
|
|
||||||
sniSlice, ok := cryptoData[TagSNI]
|
|
||||||
if !ok {
|
|
||||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
|
||||||
}
|
|
||||||
sni := string(sniSlice)
|
|
||||||
if sni == "" {
|
|
||||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
|
||||||
}
|
|
||||||
h.sni = sni
|
|
||||||
|
|
||||||
// prevent version downgrade attacks
|
|
||||||
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
|
|
||||||
verSlice, ok := cryptoData[TagVER]
|
|
||||||
if !ok {
|
|
||||||
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")
|
|
||||||
}
|
|
||||||
if len(verSlice) != 4 {
|
|
||||||
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")
|
|
||||||
}
|
|
||||||
ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice))
|
|
||||||
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
|
|
||||||
if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
|
|
||||||
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
|
||||||
}
|
|
||||||
|
|
||||||
var reply []byte
|
|
||||||
var err error
|
|
||||||
|
|
||||||
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
params, err := readHelloMap(cryptoData)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
// blocks until the session has received the parameters
|
|
||||||
if !h.receivedParams {
|
|
||||||
h.receivedParams = true
|
|
||||||
h.paramsChan <- *params
|
|
||||||
}
|
|
||||||
|
|
||||||
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
|
|
||||||
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
|
|
||||||
reply, err = h.handleCHLO(sni, chloData, cryptoData)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if _, err := h.cryptoStream.Write(reply); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
h.handshakeEvent <- struct{}{}
|
|
||||||
close(h.sentSHLO)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// We have an inchoate or non-matching CHLO, we now send a rejection
|
|
||||||
reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
_, err = h.cryptoStream.Write(reply)
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open a message
|
|
||||||
func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
if h.forwardSecureAEAD != nil {
|
|
||||||
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
|
||||||
if err == nil {
|
|
||||||
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
|
||||||
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
|
|
||||||
h.receivedForwardSecurePacket = true
|
|
||||||
// wait for the send on the handshakeEvent chan
|
|
||||||
<-h.sentSHLO
|
|
||||||
close(h.handshakeEvent)
|
|
||||||
}
|
|
||||||
return res, protocol.EncryptionForwardSecure, nil
|
|
||||||
}
|
|
||||||
if h.receivedForwardSecurePacket {
|
|
||||||
return nil, protocol.EncryptionUnspecified, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if h.secureAEAD != nil {
|
|
||||||
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
|
||||||
if err == nil {
|
|
||||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
|
||||||
h.receivedSecurePacket = true
|
|
||||||
return res, protocol.EncryptionSecure, nil
|
|
||||||
}
|
|
||||||
if h.receivedSecurePacket {
|
|
||||||
return nil, protocol.EncryptionUnspecified, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
|
||||||
if err != nil {
|
|
||||||
return res, protocol.EncryptionUnspecified, err
|
|
||||||
}
|
|
||||||
return res, protocol.EncryptionUnencrypted, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
if h.forwardSecureAEAD != nil {
|
|
||||||
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
|
|
||||||
}
|
|
||||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
if h.secureAEAD != nil {
|
|
||||||
return protocol.EncryptionSecure, h.secureAEAD
|
|
||||||
}
|
|
||||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
switch encLevel {
|
|
||||||
case protocol.EncryptionUnencrypted:
|
|
||||||
return h.nullAEAD, nil
|
|
||||||
case protocol.EncryptionSecure:
|
|
||||||
if h.secureAEAD == nil {
|
|
||||||
return nil, errors.New("CryptoSetupServer: no secureAEAD")
|
|
||||||
}
|
|
||||||
return h.secureAEAD, nil
|
|
||||||
case protocol.EncryptionForwardSecure:
|
|
||||||
if h.forwardSecureAEAD == nil {
|
|
||||||
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
|
|
||||||
}
|
|
||||||
return h.forwardSecureAEAD, nil
|
|
||||||
}
|
|
||||||
return nil, errors.New("CryptoSetupServer: no encryption level specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
|
|
||||||
if _, ok := cryptoData[TagPUBS]; !ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
scid, ok := cryptoData[TagSCID]
|
|
||||||
if !ok || !bytes.Equal(h.scfg.ID, scid) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
xlctTag, ok := cryptoData[TagXLCT]
|
|
||||||
if !ok || len(xlctTag) != 8 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
xlct := binary.LittleEndian.Uint64(xlctTag)
|
|
||||||
if crypto.HashCert(cert) != xlct {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return !h.acceptSTK(cryptoData[TagSTK])
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
|
||||||
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Debugf("STK invalid: %s", err.Error())
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return h.acceptSTKCallback(h.remoteAddr, stk)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
|
||||||
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
replyMap := map[Tag][]byte{
|
|
||||||
TagSCFG: h.scfg.Get(),
|
|
||||||
TagSTK: token,
|
|
||||||
TagSVID: []byte("quic-go"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.acceptSTK(cryptoData[TagSTK]) {
|
|
||||||
proof, err := h.scfg.Sign(sni, chlo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
commonSetHashes := cryptoData[TagCCS]
|
|
||||||
cachedCertsHashes := cryptoData[TagCCRT]
|
|
||||||
|
|
||||||
certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Token was valid, send more details
|
|
||||||
replyMap[TagPROF] = proof
|
|
||||||
replyMap[TagCERT] = certCompressed
|
|
||||||
}
|
|
||||||
|
|
||||||
message := HandshakeMessage{
|
|
||||||
Tag: TagREJ,
|
|
||||||
Data: replyMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
var serverReply bytes.Buffer
|
|
||||||
message.Write(&serverReply)
|
|
||||||
h.logger.Debugf("Sending %s", message)
|
|
||||||
return serverReply.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
|
||||||
// We have a CHLO matching our server config, we can continue with the 0-RTT handshake
|
|
||||||
sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
|
|
||||||
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverNonce := make([]byte, 32)
|
|
||||||
if _, err = rand.Read(serverNonce); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientNonce := cryptoData[TagNONC]
|
|
||||||
err = h.validateClientNonce(clientNonce)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
aead := cryptoData[TagAEAD]
|
|
||||||
if !bytes.Equal(aead, []byte("AESG")) {
|
|
||||||
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
|
|
||||||
}
|
|
||||||
|
|
||||||
kexs := cryptoData[TagKEXS]
|
|
||||||
if !bytes.Equal(kexs, []byte("C255")) {
|
|
||||||
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
|
|
||||||
}
|
|
||||||
|
|
||||||
h.secureAEAD, err = h.keyDerivation(
|
|
||||||
false,
|
|
||||||
sharedSecret,
|
|
||||||
clientNonce,
|
|
||||||
h.connID,
|
|
||||||
data,
|
|
||||||
h.scfg.Get(),
|
|
||||||
certUncompressed,
|
|
||||||
h.diversificationNonce,
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
|
||||||
h.handshakeEvent <- struct{}{}
|
|
||||||
|
|
||||||
// Generate a new curve instance to derive the forward secure key
|
|
||||||
var fsNonce bytes.Buffer
|
|
||||||
fsNonce.Write(clientNonce)
|
|
||||||
fsNonce.Write(serverNonce)
|
|
||||||
ephermalKex, err := h.keyExchange()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.forwardSecureAEAD, err = h.keyDerivation(
|
|
||||||
true,
|
|
||||||
ephermalSharedSecret,
|
|
||||||
fsNonce.Bytes(),
|
|
||||||
h.connID,
|
|
||||||
data,
|
|
||||||
h.scfg.Get(),
|
|
||||||
certUncompressed,
|
|
||||||
nil,
|
|
||||||
protocol.PerspectiveServer,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
|
|
||||||
|
|
||||||
replyMap := h.params.getHelloMap()
|
|
||||||
// add crypto parameters
|
|
||||||
verTag := &bytes.Buffer{}
|
|
||||||
for _, v := range h.supportedVersions {
|
|
||||||
utils.BigEndian.WriteUint32(verTag, uint32(v))
|
|
||||||
}
|
|
||||||
replyMap[TagPUBS] = ephermalKex.PublicKey()
|
|
||||||
replyMap[TagSNO] = serverNonce
|
|
||||||
replyMap[TagVER] = verTag.Bytes()
|
|
||||||
|
|
||||||
// note that the SHLO *has* to fit into one packet
|
|
||||||
message := HandshakeMessage{
|
|
||||||
Tag: TagSHLO,
|
|
||||||
Data: replyMap,
|
|
||||||
}
|
|
||||||
var reply bytes.Buffer
|
|
||||||
message.Write(&reply)
|
|
||||||
h.logger.Debugf("Sending %s", message)
|
|
||||||
return reply.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
return ConnectionState{
|
|
||||||
ServerName: h.sni,
|
|
||||||
HandshakeComplete: h.receivedForwardSecurePacket,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
|
||||||
if len(nonce) != 32 {
|
|
||||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(nonce[4:12], h.scfg.obit) {
|
|
||||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
163
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
163
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
generated
vendored
|
@ -1,163 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KeyDerivationFunction is used for key derivation
|
|
||||||
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
|
||||||
|
|
||||||
type cryptoSetupTLS struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
perspective protocol.Perspective
|
|
||||||
|
|
||||||
keyDerivation KeyDerivationFunction
|
|
||||||
nullAEAD crypto.AEAD
|
|
||||||
aead crypto.AEAD
|
|
||||||
|
|
||||||
tls mintTLS
|
|
||||||
conn *cryptoStreamConn
|
|
||||||
handshakeEvent chan<- struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
|
||||||
|
|
||||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
|
||||||
func NewCryptoSetupTLSServer(
|
|
||||||
cryptoStream io.ReadWriter,
|
|
||||||
connID protocol.ConnectionID,
|
|
||||||
config *mint.Config,
|
|
||||||
handshakeEvent chan<- struct{},
|
|
||||||
version protocol.VersionNumber,
|
|
||||||
) (CryptoSetupTLS, error) {
|
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn := newCryptoStreamConn(cryptoStream)
|
|
||||||
tls := mint.Server(conn, config)
|
|
||||||
return &cryptoSetupTLS{
|
|
||||||
tls: tls,
|
|
||||||
conn: conn,
|
|
||||||
nullAEAD: nullAEAD,
|
|
||||||
perspective: protocol.PerspectiveServer,
|
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
|
||||||
handshakeEvent: handshakeEvent,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
|
||||||
func NewCryptoSetupTLSClient(
|
|
||||||
cryptoStream io.ReadWriter,
|
|
||||||
connID protocol.ConnectionID,
|
|
||||||
config *mint.Config,
|
|
||||||
handshakeEvent chan<- struct{},
|
|
||||||
version protocol.VersionNumber,
|
|
||||||
) (CryptoSetupTLS, error) {
|
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn := newCryptoStreamConn(cryptoStream)
|
|
||||||
tls := mint.Client(conn, config)
|
|
||||||
return &cryptoSetupTLS{
|
|
||||||
tls: tls,
|
|
||||||
conn: conn,
|
|
||||||
perspective: protocol.PerspectiveClient,
|
|
||||||
nullAEAD: nullAEAD,
|
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
|
||||||
handshakeEvent: handshakeEvent,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
|
||||||
for {
|
|
||||||
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
|
||||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
|
||||||
}
|
|
||||||
state := h.tls.ConnectionState().HandshakeState
|
|
||||||
if err := h.conn.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if state == mint.StateClientConnected || state == mint.StateServerConnected {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
aead, err := h.keyDerivation(h.tls, h.perspective)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
h.mutex.Lock()
|
|
||||||
h.aead = aead
|
|
||||||
h.mutex.Unlock()
|
|
||||||
|
|
||||||
h.handshakeEvent <- struct{}{}
|
|
||||||
close(h.handshakeEvent)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
|
||||||
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
if h.aead == nil {
|
|
||||||
return nil, errors.New("no 1-RTT sealer")
|
|
||||||
}
|
|
||||||
return h.aead.Open(dst, src, packetNumber, associatedData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
if h.aead != nil {
|
|
||||||
return protocol.EncryptionForwardSecure, h.aead
|
|
||||||
}
|
|
||||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
|
||||||
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", encLevel.String())
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
|
|
||||||
switch encLevel {
|
|
||||||
case protocol.EncryptionUnencrypted:
|
|
||||||
return h.nullAEAD, nil
|
|
||||||
case protocol.EncryptionForwardSecure:
|
|
||||||
if h.aead == nil {
|
|
||||||
return nil, errNoSealer
|
|
||||||
}
|
|
||||||
return h.aead, nil
|
|
||||||
default:
|
|
||||||
return nil, errNoSealer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
|
||||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
mintConnState := h.tls.ConnectionState()
|
|
||||||
return ConnectionState{
|
|
||||||
// TODO: set the ServerName, once mint exports it
|
|
||||||
HandshakeComplete: h.aead != nil,
|
|
||||||
PeerCertificates: mintConnState.PeerCertificates,
|
|
||||||
}
|
|
||||||
}
|
|
69
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
69
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
generated
vendored
|
@ -1,69 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type cryptoStreamConn struct {
|
|
||||||
buffer *bytes.Buffer
|
|
||||||
stream io.ReadWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ net.Conn = &cryptoStreamConn{}
|
|
||||||
|
|
||||||
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
|
|
||||||
return &cryptoStreamConn{
|
|
||||||
stream: stream,
|
|
||||||
buffer: &bytes.Buffer{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
|
|
||||||
return c.stream.Read(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
|
|
||||||
return c.buffer.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cryptoStreamConn) Flush() error {
|
|
||||||
if c.buffer.Len() == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
_, err := c.stream.Write(c.buffer.Bytes())
|
|
||||||
c.buffer.Reset()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close is not implemented
|
|
||||||
func (c *cryptoStreamConn) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalAddr is not implemented
|
|
||||||
func (c *cryptoStreamConn) LocalAddr() net.Addr {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteAddr is not implemented
|
|
||||||
func (c *cryptoStreamConn) RemoteAddr() net.Addr {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetReadDeadline is not implemented
|
|
||||||
func (c *cryptoStreamConn) SetReadDeadline(time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWriteDeadline is not implemented
|
|
||||||
func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDeadline is not implemented
|
|
||||||
func (c *cryptoStreamConn) SetDeadline(time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,48 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
kexLifetime = protocol.EphermalKeyLifetime
|
|
||||||
kexCurrent crypto.KeyExchange
|
|
||||||
kexCurrentTime time.Time
|
|
||||||
kexMutex sync.RWMutex
|
|
||||||
)
|
|
||||||
|
|
||||||
// getEphermalKEX returns the currently active KEX, which changes every protocol.EphermalKeyLifetime
|
|
||||||
// See the explanation from the QUIC crypto doc:
|
|
||||||
//
|
|
||||||
// A single connection is the usual scope for forward security, but the security
|
|
||||||
// difference between an ephemeral key used for a single connection, and one
|
|
||||||
// used for all connections for 60 seconds is negligible. Thus we can amortise
|
|
||||||
// the Diffie-Hellman key generation at the server over all the connections in a
|
|
||||||
// small time span.
|
|
||||||
func getEphermalKEX() (crypto.KeyExchange, error) {
|
|
||||||
kexMutex.RLock()
|
|
||||||
res := kexCurrent
|
|
||||||
t := kexCurrentTime
|
|
||||||
kexMutex.RUnlock()
|
|
||||||
if res != nil && time.Since(t) < kexLifetime {
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
kexMutex.Lock()
|
|
||||||
defer kexMutex.Unlock()
|
|
||||||
// Check if still unfulfilled
|
|
||||||
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
|
|
||||||
kex, err := crypto.NewCurve25519KEX()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
kexCurrent = kex
|
|
||||||
kexCurrentTime = time.Now()
|
|
||||||
return kexCurrent, nil
|
|
||||||
}
|
|
||||||
return kexCurrent, nil
|
|
||||||
}
|
|
137
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
137
vendor/github.com/lucas-clemente/quic-go/internal/handshake/handshake_message.go
generated
vendored
|
@ -1,137 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A HandshakeMessage is a handshake message
|
|
||||||
type HandshakeMessage struct {
|
|
||||||
Tag Tag
|
|
||||||
Data map[Tag][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ fmt.Stringer = &HandshakeMessage{}
|
|
||||||
|
|
||||||
// ParseHandshakeMessage reads a crypto message
|
|
||||||
func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) {
|
|
||||||
slice4 := make([]byte, 4)
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, slice4); err != nil {
|
|
||||||
return HandshakeMessage{}, err
|
|
||||||
}
|
|
||||||
messageTag := Tag(binary.LittleEndian.Uint32(slice4))
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, slice4); err != nil {
|
|
||||||
return HandshakeMessage{}, err
|
|
||||||
}
|
|
||||||
nPairs := binary.LittleEndian.Uint32(slice4)
|
|
||||||
|
|
||||||
if nPairs > protocol.CryptoMaxParams {
|
|
||||||
return HandshakeMessage{}, qerr.CryptoTooManyEntries
|
|
||||||
}
|
|
||||||
|
|
||||||
index := make([]byte, nPairs*8)
|
|
||||||
if _, err := io.ReadFull(r, index); err != nil {
|
|
||||||
return HandshakeMessage{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resultMap := map[Tag][]byte{}
|
|
||||||
|
|
||||||
var dataStart uint32
|
|
||||||
for indexPos := 0; indexPos < int(nPairs)*8; indexPos += 8 {
|
|
||||||
tag := Tag(binary.LittleEndian.Uint32(index[indexPos : indexPos+4]))
|
|
||||||
dataEnd := binary.LittleEndian.Uint32(index[indexPos+4 : indexPos+8])
|
|
||||||
|
|
||||||
dataLen := dataEnd - dataStart
|
|
||||||
if dataLen > protocol.CryptoParameterMaxLength {
|
|
||||||
return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
data := make([]byte, dataLen)
|
|
||||||
if _, err := io.ReadFull(r, data); err != nil {
|
|
||||||
return HandshakeMessage{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resultMap[tag] = data
|
|
||||||
dataStart = dataEnd
|
|
||||||
}
|
|
||||||
|
|
||||||
return HandshakeMessage{
|
|
||||||
Tag: messageTag,
|
|
||||||
Data: resultMap}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes a crypto message
|
|
||||||
func (h HandshakeMessage) Write(b *bytes.Buffer) {
|
|
||||||
data := h.Data
|
|
||||||
utils.LittleEndian.WriteUint32(b, uint32(h.Tag))
|
|
||||||
utils.LittleEndian.WriteUint16(b, uint16(len(data)))
|
|
||||||
utils.LittleEndian.WriteUint16(b, 0)
|
|
||||||
|
|
||||||
// Save current position in the buffer, so that we can update the index in-place later
|
|
||||||
indexStart := b.Len()
|
|
||||||
|
|
||||||
indexData := make([]byte, 8*len(data))
|
|
||||||
b.Write(indexData) // Will be updated later
|
|
||||||
|
|
||||||
offset := uint32(0)
|
|
||||||
for i, t := range h.getTagsSorted() {
|
|
||||||
v := data[t]
|
|
||||||
b.Write(v)
|
|
||||||
offset += uint32(len(v))
|
|
||||||
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
|
|
||||||
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now we write the index data for real
|
|
||||||
copy(b.Bytes()[indexStart:], indexData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HandshakeMessage) getTagsSorted() []Tag {
|
|
||||||
tags := make([]Tag, len(h.Data))
|
|
||||||
i := 0
|
|
||||||
for t := range h.Data {
|
|
||||||
tags[i] = t
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
sort.Slice(tags, func(i, j int) bool {
|
|
||||||
return tags[i] < tags[j]
|
|
||||||
})
|
|
||||||
return tags
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h HandshakeMessage) String() string {
|
|
||||||
var pad string
|
|
||||||
res := tagToString(h.Tag) + ":\n"
|
|
||||||
for _, tag := range h.getTagsSorted() {
|
|
||||||
if tag == TagPAD {
|
|
||||||
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
|
|
||||||
} else {
|
|
||||||
res += fmt.Sprintf("\t%s: %#v\n", tagToString(tag), string(h.Data[tag]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(pad) > 0 {
|
|
||||||
res += pad
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func tagToString(tag Tag) string {
|
|
||||||
b := make([]byte, 4)
|
|
||||||
binary.LittleEndian.PutUint32(b, uint32(tag))
|
|
||||||
for i := range b {
|
|
||||||
if b[i] == 0 {
|
|
||||||
b[i] = ' '
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return string(b)
|
|
||||||
}
|
|
|
@ -2,54 +2,43 @@ package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"io"
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/marten-seemann/qtls"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Opener opens a packet
|
||||||
|
type Opener interface {
|
||||||
|
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
// Sealer seals a packet
|
// Sealer seals a packet
|
||||||
type Sealer interface {
|
type Sealer interface {
|
||||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||||
Overhead() int
|
Overhead() int
|
||||||
}
|
}
|
||||||
|
|
||||||
// mintTLS combines some methods needed to interact with mint.
|
// A tlsExtensionHandler sends and received the QUIC TLS extension.
|
||||||
type mintTLS interface {
|
type tlsExtensionHandler interface {
|
||||||
crypto.TLSExporter
|
GetExtensions(msgType uint8) []qtls.Extension
|
||||||
Handshake() mint.Alert
|
ReceivedExtensions(msgType uint8, exts []qtls.Extension) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// A TLSExtensionHandler sends and received the QUIC TLS extension.
|
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||||
// It provides the parameters sent by the peer on a channel.
|
type CryptoSetup interface {
|
||||||
type TLSExtensionHandler interface {
|
RunHandshake() error
|
||||||
Send(mint.HandshakeType, *mint.ExtensionList) error
|
io.Closer
|
||||||
Receive(mint.HandshakeType, *mint.ExtensionList) error
|
|
||||||
GetPeerParams() <-chan TransportParameters
|
|
||||||
}
|
|
||||||
|
|
||||||
type baseCryptoSetup interface {
|
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||||
HandleCryptoStream() error
|
|
||||||
ConnectionState() ConnectionState
|
ConnectionState() ConnectionState
|
||||||
|
|
||||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CryptoSetup is the crypto setup used by gQUIC
|
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||||
type CryptoSetup interface {
|
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||||
baseCryptoSetup
|
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||||
|
|
||||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CryptoSetupTLS is the crypto setup used by IETF QUIC
|
|
||||||
type CryptoSetupTLS interface {
|
|
||||||
baseCryptoSetup
|
|
||||||
|
|
||||||
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
|
||||||
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionState records basic details about the QUIC connection.
|
// ConnectionState records basic details about the QUIC connection.
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS"
|
|
48
vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go
generated
vendored
Normal file
48
vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go
generated
vendored
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
|
||||||
|
"github.com/marten-seemann/qtls"
|
||||||
|
)
|
||||||
|
|
||||||
|
func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
||||||
|
if c == nil {
|
||||||
|
c = &tls.Config{}
|
||||||
|
}
|
||||||
|
// QUIC requires TLS 1.3 or newer
|
||||||
|
if c.MinVersion < qtls.VersionTLS13 {
|
||||||
|
c.MinVersion = qtls.VersionTLS13
|
||||||
|
}
|
||||||
|
if c.MaxVersion < qtls.VersionTLS13 {
|
||||||
|
c.MaxVersion = qtls.VersionTLS13
|
||||||
|
}
|
||||||
|
return &qtls.Config{
|
||||||
|
Rand: c.Rand,
|
||||||
|
Time: c.Time,
|
||||||
|
Certificates: c.Certificates,
|
||||||
|
NameToCertificate: c.NameToCertificate,
|
||||||
|
// TODO: make GetCertificate work
|
||||||
|
// GetCertificate: c.GetCertificate,
|
||||||
|
GetClientCertificate: c.GetClientCertificate,
|
||||||
|
// TODO: make GetConfigForClient work
|
||||||
|
// GetConfigForClient: c.GetConfigForClient,
|
||||||
|
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||||
|
RootCAs: c.RootCAs,
|
||||||
|
NextProtos: c.NextProtos,
|
||||||
|
ServerName: c.ServerName,
|
||||||
|
ClientAuth: c.ClientAuth,
|
||||||
|
ClientCAs: c.ClientCAs,
|
||||||
|
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||||
|
CipherSuites: c.CipherSuites,
|
||||||
|
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||||
|
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||||
|
SessionTicketKey: c.SessionTicketKey,
|
||||||
|
MinVersion: c.MinVersion,
|
||||||
|
MaxVersion: c.MaxVersion,
|
||||||
|
CurvePreferences: c.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||||
|
Renegotiation: c.Renegotiation,
|
||||||
|
KeyLogWriter: c.KeyLogWriter,
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,73 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ServerConfig is a server config
|
|
||||||
type ServerConfig struct {
|
|
||||||
kex crypto.KeyExchange
|
|
||||||
certChain crypto.CertChain
|
|
||||||
ID []byte
|
|
||||||
obit []byte
|
|
||||||
cookieGenerator *CookieGenerator
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewServerConfig creates a new server config
|
|
||||||
func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
|
|
||||||
id := make([]byte, 16)
|
|
||||||
_, err := rand.Read(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
obit := make([]byte, 8)
|
|
||||||
if _, err = rand.Read(obit); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cookieGenerator, err := NewCookieGenerator()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ServerConfig{
|
|
||||||
kex: kex,
|
|
||||||
certChain: certChain,
|
|
||||||
ID: id,
|
|
||||||
obit: obit,
|
|
||||||
cookieGenerator: cookieGenerator,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the server config binary representation
|
|
||||||
func (s *ServerConfig) Get() []byte {
|
|
||||||
var serverConfig bytes.Buffer
|
|
||||||
msg := HandshakeMessage{
|
|
||||||
Tag: TagSCFG,
|
|
||||||
Data: map[Tag][]byte{
|
|
||||||
TagSCID: s.ID,
|
|
||||||
TagKEXS: []byte("C255"),
|
|
||||||
TagAEAD: []byte("AESG"),
|
|
||||||
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
|
|
||||||
TagOBIT: s.obit,
|
|
||||||
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
msg.Write(&serverConfig)
|
|
||||||
return serverConfig.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign the server config and CHLO with the server's keyData
|
|
||||||
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
|
|
||||||
return s.certChain.SignServerProof(sni, chlo, s.Get())
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCertsCompressed returns the certificate data
|
|
||||||
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
|
|
||||||
return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
|
|
||||||
}
|
|
184
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go
generated
vendored
184
vendor/github.com/lucas-clemente/quic-go/internal/handshake/server_config_client.go
generated
vendored
|
@ -1,184 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"math"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
type serverConfigClient struct {
|
|
||||||
raw []byte
|
|
||||||
ID []byte
|
|
||||||
obit []byte
|
|
||||||
expiry time.Time
|
|
||||||
|
|
||||||
kex crypto.KeyExchange
|
|
||||||
sharedSecret []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
|
|
||||||
)
|
|
||||||
|
|
||||||
// parseServerConfig parses a server config
|
|
||||||
func parseServerConfig(data []byte) (*serverConfigClient, error) {
|
|
||||||
message, err := ParseHandshakeMessage(bytes.NewReader(data))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if message.Tag != TagSCFG {
|
|
||||||
return nil, errMessageNotServerConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
scfg := &serverConfigClient{raw: data}
|
|
||||||
err = scfg.parseValues(message.Data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return scfg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
|
|
||||||
// SCID
|
|
||||||
scfgID, ok := tagMap[TagSCID]
|
|
||||||
if !ok {
|
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
|
|
||||||
}
|
|
||||||
if len(scfgID) != 16 {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
|
|
||||||
}
|
|
||||||
s.ID = scfgID
|
|
||||||
|
|
||||||
// KEXS
|
|
||||||
// TODO: setup Key Exchange
|
|
||||||
kexs, ok := tagMap[TagKEXS]
|
|
||||||
if !ok {
|
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
|
|
||||||
}
|
|
||||||
if len(kexs)%4 != 0 {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
|
|
||||||
}
|
|
||||||
c255Foundat := -1
|
|
||||||
|
|
||||||
for i := 0; i < len(kexs)/4; i++ {
|
|
||||||
if bytes.Equal(kexs[4*i:4*i+4], []byte("C255")) {
|
|
||||||
c255Foundat = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c255Foundat < 0 {
|
|
||||||
return qerr.Error(qerr.CryptoNoSupport, "KEXS: Could not find C255, other key exchanges are not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
// AEAD
|
|
||||||
aead, ok := tagMap[TagAEAD]
|
|
||||||
if !ok {
|
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
|
|
||||||
}
|
|
||||||
if len(aead)%4 != 0 {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
|
|
||||||
}
|
|
||||||
var aesgFound bool
|
|
||||||
for i := 0; i < len(aead)/4; i++ {
|
|
||||||
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
|
|
||||||
aesgFound = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !aesgFound {
|
|
||||||
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
|
|
||||||
}
|
|
||||||
|
|
||||||
// PUBS
|
|
||||||
pubs, ok := tagMap[TagPUBS]
|
|
||||||
if !ok {
|
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
|
||||||
}
|
|
||||||
|
|
||||||
var pubsKexs []struct {
|
|
||||||
Length uint32
|
|
||||||
Value []byte
|
|
||||||
}
|
|
||||||
var lastLen uint32
|
|
||||||
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
|
|
||||||
// the PUBS value is always prepended by 3 byte little endian length field
|
|
||||||
|
|
||||||
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
|
|
||||||
if err != nil {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
|
|
||||||
}
|
|
||||||
if lastLen == 0 {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
|
||||||
}
|
|
||||||
|
|
||||||
if i+3+int(lastLen) > len(pubs) {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
|
||||||
}
|
|
||||||
|
|
||||||
pubsKexs = append(pubsKexs, struct {
|
|
||||||
Length uint32
|
|
||||||
Value []byte
|
|
||||||
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
|
|
||||||
}
|
|
||||||
|
|
||||||
if c255Foundat >= len(pubsKexs) {
|
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
|
|
||||||
}
|
|
||||||
|
|
||||||
if pubsKexs[c255Foundat].Length != 32 {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
s.kex, err = crypto.NewCurve25519KEX()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// OBIT
|
|
||||||
obit, ok := tagMap[TagOBIT]
|
|
||||||
if !ok {
|
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
|
|
||||||
}
|
|
||||||
if len(obit) != 8 {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
|
|
||||||
}
|
|
||||||
s.obit = obit
|
|
||||||
|
|
||||||
// EXPY
|
|
||||||
expy, ok := tagMap[TagEXPY]
|
|
||||||
if !ok {
|
|
||||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
|
|
||||||
}
|
|
||||||
if len(expy) != 8 {
|
|
||||||
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
|
|
||||||
}
|
|
||||||
// make sure that the value doesn't overflow an int64
|
|
||||||
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
|
|
||||||
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
|
|
||||||
s.expiry = time.Unix(int64(expyTimestamp), 0)
|
|
||||||
|
|
||||||
// TODO: implement VER
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverConfigClient) IsExpired() bool {
|
|
||||||
return s.expiry.Before(time.Now())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serverConfigClient) Get() []byte {
|
|
||||||
return s.raw
|
|
||||||
}
|
|
|
@ -1,93 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
// A Tag in the QUIC crypto
|
|
||||||
type Tag uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
// TagCHLO is a client hello
|
|
||||||
TagCHLO Tag = 'C' + 'H'<<8 + 'L'<<16 + 'O'<<24
|
|
||||||
// TagREJ is a server hello rejection
|
|
||||||
TagREJ Tag = 'R' + 'E'<<8 + 'J'<<16
|
|
||||||
// TagSCFG is a server config
|
|
||||||
TagSCFG Tag = 'S' + 'C'<<8 + 'F'<<16 + 'G'<<24
|
|
||||||
|
|
||||||
// TagPAD is padding
|
|
||||||
TagPAD Tag = 'P' + 'A'<<8 + 'D'<<16
|
|
||||||
// TagSNI is the server name indication
|
|
||||||
TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16
|
|
||||||
// TagVER is the QUIC version
|
|
||||||
TagVER Tag = 'V' + 'E'<<8 + 'R'<<16
|
|
||||||
// TagCCS are the hashes of the common certificate sets
|
|
||||||
TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16
|
|
||||||
// TagCCRT are the hashes of the cached certificates
|
|
||||||
TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24
|
|
||||||
// TagMSPC is max streams per connection
|
|
||||||
TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24
|
|
||||||
// TagMIDS is max incoming dyanamic streams
|
|
||||||
TagMIDS Tag = 'M' + 'I'<<8 + 'D'<<16 + 'S'<<24
|
|
||||||
// TagUAID is the user agent ID
|
|
||||||
TagUAID Tag = 'U' + 'A'<<8 + 'I'<<16 + 'D'<<24
|
|
||||||
// TagSVID is the server ID (unofficial tag by us :)
|
|
||||||
TagSVID Tag = 'S' + 'V'<<8 + 'I'<<16 + 'D'<<24
|
|
||||||
// TagTCID is truncation of the connection ID
|
|
||||||
TagTCID Tag = 'T' + 'C'<<8 + 'I'<<16 + 'D'<<24
|
|
||||||
// TagPDMD is the proof demand
|
|
||||||
TagPDMD Tag = 'P' + 'D'<<8 + 'M'<<16 + 'D'<<24
|
|
||||||
// TagSRBF is the socket receive buffer
|
|
||||||
TagSRBF Tag = 'S' + 'R'<<8 + 'B'<<16 + 'F'<<24
|
|
||||||
// TagICSL is the idle connection state lifetime
|
|
||||||
TagICSL Tag = 'I' + 'C'<<8 + 'S'<<16 + 'L'<<24
|
|
||||||
// TagNONP is the client proof nonce
|
|
||||||
TagNONP Tag = 'N' + 'O'<<8 + 'N'<<16 + 'P'<<24
|
|
||||||
// TagSCLS is the silently close timeout
|
|
||||||
TagSCLS Tag = 'S' + 'C'<<8 + 'L'<<16 + 'S'<<24
|
|
||||||
// TagCSCT is the signed cert timestamp (RFC6962) of leaf cert
|
|
||||||
TagCSCT Tag = 'C' + 'S'<<8 + 'C'<<16 + 'T'<<24
|
|
||||||
// TagCOPT are the connection options
|
|
||||||
TagCOPT Tag = 'C' + 'O'<<8 + 'P'<<16 + 'T'<<24
|
|
||||||
// TagCFCW is the initial session/connection flow control receive window
|
|
||||||
TagCFCW Tag = 'C' + 'F'<<8 + 'C'<<16 + 'W'<<24
|
|
||||||
// TagSFCW is the initial stream flow control receive window.
|
|
||||||
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
|
|
||||||
|
|
||||||
// TagNSTP is the no STOP_WAITING experiment
|
|
||||||
// currently unsupported by quic-go
|
|
||||||
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
|
|
||||||
|
|
||||||
// TagSTK is the source-address token
|
|
||||||
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
|
|
||||||
// TagSNO is the server nonce
|
|
||||||
TagSNO Tag = 'S' + 'N'<<8 + 'O'<<16
|
|
||||||
// TagPROF is the server proof
|
|
||||||
TagPROF Tag = 'P' + 'R'<<8 + 'O'<<16 + 'F'<<24
|
|
||||||
|
|
||||||
// TagNONC is the client nonce
|
|
||||||
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
|
|
||||||
// TagXLCT is the expected leaf certificate
|
|
||||||
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
|
|
||||||
|
|
||||||
// TagSCID is the server config ID
|
|
||||||
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
|
|
||||||
// TagKEXS is the list of key exchange algos
|
|
||||||
TagKEXS Tag = 'K' + 'E'<<8 + 'X'<<16 + 'S'<<24
|
|
||||||
// TagAEAD is the list of AEAD algos
|
|
||||||
TagAEAD Tag = 'A' + 'E'<<8 + 'A'<<16 + 'D'<<24
|
|
||||||
// TagPUBS is the public value for the KEX
|
|
||||||
TagPUBS Tag = 'P' + 'U'<<8 + 'B'<<16 + 'S'<<24
|
|
||||||
// TagOBIT is the client orbit
|
|
||||||
TagOBIT Tag = 'O' + 'B'<<8 + 'I'<<16 + 'T'<<24
|
|
||||||
// TagEXPY is the server config expiry
|
|
||||||
TagEXPY Tag = 'E' + 'X'<<8 + 'P'<<16 + 'Y'<<24
|
|
||||||
// TagCERT is the CERT data
|
|
||||||
TagCERT Tag = 0xff545243
|
|
||||||
|
|
||||||
// TagSHLO is the server hello
|
|
||||||
TagSHLO Tag = 'S' + 'H'<<8 + 'L'<<16 + 'O'<<24
|
|
||||||
|
|
||||||
// TagPRST is the public reset tag
|
|
||||||
TagPRST Tag = 'P' + 'R'<<8 + 'S'<<16 + 'T'<<24
|
|
||||||
// TagRSEQ is the public reset rejected packet number
|
|
||||||
TagRSEQ Tag = 'R' + 'S'<<8 + 'E'<<16 + 'Q'<<24
|
|
||||||
// TagRNON is the public reset nonce
|
|
||||||
TagRNON Tag = 'R' + 'N'<<8 + 'O'<<16 + 'N'<<24
|
|
||||||
)
|
|
|
@ -6,26 +6,12 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type transportParameterID uint16
|
|
||||||
|
|
||||||
const quicTLSExtensionType = 0xff5
|
const quicTLSExtensionType = 0xff5
|
||||||
|
|
||||||
const (
|
|
||||||
initialMaxStreamDataParameterID transportParameterID = 0x0
|
|
||||||
initialMaxDataParameterID transportParameterID = 0x1
|
|
||||||
initialMaxBidiStreamsParameterID transportParameterID = 0x2
|
|
||||||
idleTimeoutParameterID transportParameterID = 0x3
|
|
||||||
maxPacketSizeParameterID transportParameterID = 0x5
|
|
||||||
statelessResetTokenParameterID transportParameterID = 0x6
|
|
||||||
initialMaxUniStreamsParameterID transportParameterID = 0x8
|
|
||||||
disableMigrationParameterID transportParameterID = 0x9
|
|
||||||
)
|
|
||||||
|
|
||||||
type clientHelloTransportParameters struct {
|
type clientHelloTransportParameters struct {
|
||||||
InitialVersion protocol.VersionNumber
|
InitialVersion protocol.VersionNumber
|
||||||
Parameters TransportParameters
|
Parameters TransportParameters
|
||||||
|
@ -52,7 +38,7 @@ func (p *clientHelloTransportParameters) Unmarshal(data []byte) error {
|
||||||
if len(data) != paramsLen {
|
if len(data) != paramsLen {
|
||||||
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
|
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
|
||||||
}
|
}
|
||||||
return p.Parameters.unmarshal(data)
|
return p.Parameters.unmarshal(data, protocol.PerspectiveClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
type encryptedExtensionsTransportParameters struct {
|
type encryptedExtensionsTransportParameters struct {
|
||||||
|
@ -100,24 +86,5 @@ func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error {
|
||||||
if len(data) != paramsLen {
|
if len(data) != paramsLen {
|
||||||
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
|
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
|
||||||
}
|
}
|
||||||
return p.Parameters.unmarshal(data)
|
return p.Parameters.unmarshal(data, protocol.PerspectiveServer)
|
||||||
}
|
|
||||||
|
|
||||||
type tlsExtensionBody struct {
|
|
||||||
data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ mint.ExtensionBody = &tlsExtensionBody{}
|
|
||||||
|
|
||||||
func (e *tlsExtensionBody) Type() mint.ExtensionType {
|
|
||||||
return quicTLSExtensionType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *tlsExtensionBody) Marshal() ([]byte, error) {
|
|
||||||
return e.data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *tlsExtensionBody) Unmarshal(data []byte) (int, error) {
|
|
||||||
e.data = data
|
|
||||||
return len(data), nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,17 +4,17 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/marten-seemann/qtls"
|
||||||
)
|
)
|
||||||
|
|
||||||
type extensionHandlerClient struct {
|
type extensionHandlerClient struct {
|
||||||
ourParams *TransportParameters
|
ourParams *TransportParameters
|
||||||
paramsChan chan TransportParameters
|
paramsChan chan<- TransportParameters
|
||||||
|
|
||||||
|
origConnID protocol.ConnectionID
|
||||||
initialVersion protocol.VersionNumber
|
initialVersion protocol.VersionNumber
|
||||||
supportedVersions []protocol.VersionNumber
|
supportedVersions []protocol.VersionNumber
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
@ -22,17 +22,17 @@ type extensionHandlerClient struct {
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
|
var _ tlsExtensionHandler = &extensionHandlerClient{}
|
||||||
var _ TLSExtensionHandler = &extensionHandlerClient{}
|
|
||||||
|
|
||||||
// NewExtensionHandlerClient creates a new extension handler for the client.
|
// newExtensionHandlerClient creates a new extension handler for the client.
|
||||||
func NewExtensionHandlerClient(
|
func newExtensionHandlerClient(
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
|
origConnID protocol.ConnectionID,
|
||||||
initialVersion protocol.VersionNumber,
|
initialVersion protocol.VersionNumber,
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
) TLSExtensionHandler {
|
) (tlsExtensionHandler, <-chan TransportParameters) {
|
||||||
// The client reads the transport parameters from the Encrypted Extensions message.
|
// The client reads the transport parameters from the Encrypted Extensions message.
|
||||||
// The paramsChan is used in the session's run loop's select statement.
|
// The paramsChan is used in the session's run loop's select statement.
|
||||||
// We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately.
|
// We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately.
|
||||||
|
@ -40,48 +40,48 @@ func NewExtensionHandlerClient(
|
||||||
return &extensionHandlerClient{
|
return &extensionHandlerClient{
|
||||||
ourParams: params,
|
ourParams: params,
|
||||||
paramsChan: paramsChan,
|
paramsChan: paramsChan,
|
||||||
|
origConnID: origConnID,
|
||||||
initialVersion: initialVersion,
|
initialVersion: initialVersion,
|
||||||
supportedVersions: supportedVersions,
|
supportedVersions: supportedVersions,
|
||||||
version: version,
|
version: version,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}, paramsChan
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
func (h *extensionHandlerClient) GetExtensions(msgType uint8) []qtls.Extension {
|
||||||
if hType != mint.HandshakeTypeClientHello {
|
if messageType(msgType) != typeClientHello {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
|
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
|
||||||
chtp := &clientHelloTransportParameters{
|
return []qtls.Extension{{
|
||||||
|
Type: quicTLSExtensionType,
|
||||||
|
Data: (&clientHelloTransportParameters{
|
||||||
InitialVersion: h.initialVersion,
|
InitialVersion: h.initialVersion,
|
||||||
Parameters: *h.ourParams,
|
Parameters: *h.ourParams,
|
||||||
}
|
}).Marshal(),
|
||||||
return el.Add(&tlsExtensionBody{data: chtp.Marshal()})
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
func (h *extensionHandlerClient) ReceivedExtensions(msgType uint8, exts []qtls.Extension) error {
|
||||||
ext := &tlsExtensionBody{}
|
if messageType(msgType) != typeEncryptedExtensions {
|
||||||
found, err := el.Find(ext)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if hType != mint.HandshakeTypeEncryptedExtensions {
|
|
||||||
if found {
|
|
||||||
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// hType == mint.HandshakeTypeEncryptedExtensions
|
var found bool
|
||||||
|
eetp := &encryptedExtensionsTransportParameters{}
|
||||||
|
for _, ext := range exts {
|
||||||
|
if ext.Type != quicTLSExtensionType {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := eetp.Unmarshal(ext.Data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
found = true
|
||||||
|
}
|
||||||
if !found {
|
if !found {
|
||||||
return errors.New("EncryptedExtensions message didn't contain a QUIC extension")
|
return errors.New("EncryptedExtensions message didn't contain a QUIC extension")
|
||||||
}
|
}
|
||||||
|
|
||||||
eetp := &encryptedExtensionsTransportParameters{}
|
|
||||||
if err := eetp.Unmarshal(ext.data); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// check that the negotiated_version is the current version
|
// check that the negotiated_version is the current version
|
||||||
if eetp.NegotiatedVersion != h.version {
|
if eetp.NegotiatedVersion != h.version {
|
||||||
return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
|
return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
|
||||||
|
@ -98,15 +98,16 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
params := eetp.Parameters
|
||||||
// check that the server sent a stateless reset token
|
// check that the server sent a stateless reset token
|
||||||
if len(eetp.Parameters.StatelessResetToken) == 0 {
|
if len(params.StatelessResetToken) == 0 {
|
||||||
return errors.New("server didn't sent stateless_reset_token")
|
return errors.New("server didn't sent stateless_reset_token")
|
||||||
}
|
}
|
||||||
h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters)
|
// check the Retry token
|
||||||
h.paramsChan <- eetp.Parameters
|
if !h.origConnID.Equal(params.OriginalConnectionID) {
|
||||||
|
return fmt.Errorf("expected original_connection_id to equal %s, is %s", h.origConnID, params.OriginalConnectionID)
|
||||||
|
}
|
||||||
|
h.logger.Debugf("Received Transport Parameters: %s", ¶ms)
|
||||||
|
h.paramsChan <- params
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
|
|
||||||
return h.paramsChan
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,18 +2,16 @@ package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/marten-seemann/qtls"
|
||||||
)
|
)
|
||||||
|
|
||||||
type extensionHandlerServer struct {
|
type extensionHandlerServer struct {
|
||||||
ourParams *TransportParameters
|
ourParams *TransportParameters
|
||||||
paramsChan chan TransportParameters
|
paramsChan chan<- TransportParameters
|
||||||
|
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
supportedVersions []protocol.VersionNumber
|
supportedVersions []protocol.VersionNumber
|
||||||
|
@ -21,62 +19,60 @@ type extensionHandlerServer struct {
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
|
var _ tlsExtensionHandler = &extensionHandlerServer{}
|
||||||
var _ TLSExtensionHandler = &extensionHandlerServer{}
|
|
||||||
|
|
||||||
// NewExtensionHandlerServer creates a new extension handler for the server
|
// newExtensionHandlerServer creates a new extension handler for the server
|
||||||
func NewExtensionHandlerServer(
|
func newExtensionHandlerServer(
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
logger utils.Logger,
|
logger utils.Logger,
|
||||||
) TLSExtensionHandler {
|
) (tlsExtensionHandler, <-chan TransportParameters) {
|
||||||
// Processing the ClientHello is performed statelessly (and from a single go-routine).
|
// Processing the ClientHello is performed statelessly (and from a single go-routine).
|
||||||
// Therefore, we have to use a buffered chan to pass the transport parameters to that go routine.
|
// Therefore, we have to use a buffered chan to pass the transport parameters to that go routine.
|
||||||
paramsChan := make(chan TransportParameters, 1)
|
paramsChan := make(chan TransportParameters)
|
||||||
return &extensionHandlerServer{
|
return &extensionHandlerServer{
|
||||||
ourParams: params,
|
ourParams: params,
|
||||||
paramsChan: paramsChan,
|
paramsChan: paramsChan,
|
||||||
supportedVersions: supportedVersions,
|
supportedVersions: supportedVersions,
|
||||||
version: version,
|
version: version,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}, paramsChan
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
func (h *extensionHandlerServer) GetExtensions(msgType uint8) []qtls.Extension {
|
||||||
if hType != mint.HandshakeTypeEncryptedExtensions {
|
if messageType(msgType) != typeEncryptedExtensions {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
|
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
|
||||||
eetp := &encryptedExtensionsTransportParameters{
|
return []qtls.Extension{{
|
||||||
|
Type: quicTLSExtensionType,
|
||||||
|
Data: (&encryptedExtensionsTransportParameters{
|
||||||
NegotiatedVersion: h.version,
|
NegotiatedVersion: h.version,
|
||||||
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
|
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
|
||||||
Parameters: *h.ourParams,
|
Parameters: *h.ourParams,
|
||||||
}
|
}).Marshal(),
|
||||||
return el.Add(&tlsExtensionBody{data: eetp.Marshal()})
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
|
func (h *extensionHandlerServer) ReceivedExtensions(msgType uint8, exts []qtls.Extension) error {
|
||||||
ext := &tlsExtensionBody{}
|
if messageType(msgType) != typeClientHello {
|
||||||
found, err := el.Find(ext)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if hType != mint.HandshakeTypeClientHello {
|
|
||||||
if found {
|
|
||||||
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
var found bool
|
||||||
|
chtp := &clientHelloTransportParameters{}
|
||||||
|
for _, ext := range exts {
|
||||||
|
if ext.Type != quicTLSExtensionType {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := chtp.Unmarshal(ext.Data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
found = true
|
||||||
|
}
|
||||||
if !found {
|
if !found {
|
||||||
return errors.New("ClientHello didn't contain a QUIC extension")
|
return errors.New("ClientHello didn't contain a QUIC extension")
|
||||||
}
|
}
|
||||||
chtp := &clientHelloTransportParameters{}
|
|
||||||
if err := chtp.Unmarshal(ext.data); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// perform the stateless version negotiation validation:
|
// perform the stateless version negotiation validation:
|
||||||
// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
|
// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
|
||||||
|
@ -84,17 +80,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
|
||||||
if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
|
if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
|
||||||
return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
|
return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
|
||||||
}
|
}
|
||||||
|
|
||||||
// check that the client didn't send a stateless reset token
|
|
||||||
if len(chtp.Parameters.StatelessResetToken) != 0 {
|
|
||||||
// TODO: return the correct error type
|
|
||||||
return errors.New("client sent a stateless reset token")
|
|
||||||
}
|
|
||||||
h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
|
h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
|
||||||
h.paramsChan <- chtp.Parameters
|
h.paramsChan <- chtp.Parameters
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
|
|
||||||
return h.paramsChan
|
|
||||||
}
|
|
||||||
|
|
287
vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go
generated
vendored
287
vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go
generated
vendored
|
@ -2,194 +2,190 @@ package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// errMalformedTag is returned when the tag value cannot be read
|
type transportParameterID uint16
|
||||||
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
|
||||||
|
const (
|
||||||
|
originalConnectionIDParameterID transportParameterID = 0x0
|
||||||
|
idleTimeoutParameterID transportParameterID = 0x1
|
||||||
|
statelessResetTokenParameterID transportParameterID = 0x2
|
||||||
|
maxPacketSizeParameterID transportParameterID = 0x3
|
||||||
|
initialMaxDataParameterID transportParameterID = 0x4
|
||||||
|
initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5
|
||||||
|
initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6
|
||||||
|
initialMaxStreamDataUniParameterID transportParameterID = 0x7
|
||||||
|
initialMaxStreamsBidiParameterID transportParameterID = 0x8
|
||||||
|
initialMaxStreamsUniParameterID transportParameterID = 0x9
|
||||||
|
disableMigrationParameterID transportParameterID = 0xc
|
||||||
|
)
|
||||||
|
|
||||||
// TransportParameters are parameters sent to the peer during the handshake
|
// TransportParameters are parameters sent to the peer during the handshake
|
||||||
type TransportParameters struct {
|
type TransportParameters struct {
|
||||||
StreamFlowControlWindow protocol.ByteCount
|
InitialMaxStreamDataBidiLocal protocol.ByteCount
|
||||||
ConnectionFlowControlWindow protocol.ByteCount
|
InitialMaxStreamDataBidiRemote protocol.ByteCount
|
||||||
|
InitialMaxStreamDataUni protocol.ByteCount
|
||||||
|
InitialMaxData protocol.ByteCount
|
||||||
|
|
||||||
MaxPacketSize protocol.ByteCount
|
MaxPacketSize protocol.ByteCount
|
||||||
|
|
||||||
MaxUniStreams uint16 // only used for IETF QUIC
|
MaxUniStreams uint64
|
||||||
MaxBidiStreams uint16 // only used for IETF QUIC
|
MaxBidiStreams uint64
|
||||||
MaxStreams uint32 // only used for gQUIC
|
|
||||||
|
|
||||||
OmitConnectionID bool // only used for gQUIC
|
|
||||||
IdleTimeout time.Duration
|
IdleTimeout time.Duration
|
||||||
DisableMigration bool // only used for IETF QUIC
|
DisableMigration bool
|
||||||
StatelessResetToken []byte // only used for IETF QUIC
|
|
||||||
|
StatelessResetToken []byte
|
||||||
|
OriginalConnectionID protocol.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
|
func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective) error {
|
||||||
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
|
// needed to check that every parameter is only sent at most once
|
||||||
params := &TransportParameters{}
|
var parameterIDs []transportParameterID
|
||||||
if value, ok := tags[TagTCID]; ok {
|
|
||||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errMalformedTag
|
|
||||||
}
|
|
||||||
params.OmitConnectionID = (v == 0)
|
|
||||||
}
|
|
||||||
if value, ok := tags[TagMIDS]; ok {
|
|
||||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errMalformedTag
|
|
||||||
}
|
|
||||||
params.MaxStreams = v
|
|
||||||
}
|
|
||||||
if value, ok := tags[TagICSL]; ok {
|
|
||||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errMalformedTag
|
|
||||||
}
|
|
||||||
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
|
|
||||||
}
|
|
||||||
if value, ok := tags[TagSFCW]; ok {
|
|
||||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errMalformedTag
|
|
||||||
}
|
|
||||||
params.StreamFlowControlWindow = protocol.ByteCount(v)
|
|
||||||
}
|
|
||||||
if value, ok := tags[TagCFCW]; ok {
|
|
||||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errMalformedTag
|
|
||||||
}
|
|
||||||
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
|
|
||||||
}
|
|
||||||
return params, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
|
r := bytes.NewReader(data)
|
||||||
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
|
for r.Len() >= 4 {
|
||||||
sfcw := bytes.NewBuffer([]byte{})
|
paramIDInt, _ := utils.BigEndian.ReadUint16(r)
|
||||||
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
|
paramID := transportParameterID(paramIDInt)
|
||||||
cfcw := bytes.NewBuffer([]byte{})
|
paramLen, _ := utils.BigEndian.ReadUint16(r)
|
||||||
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
|
parameterIDs = append(parameterIDs, paramID)
|
||||||
mids := bytes.NewBuffer([]byte{})
|
switch paramID {
|
||||||
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
|
case initialMaxStreamDataBidiLocalParameterID,
|
||||||
icsl := bytes.NewBuffer([]byte{})
|
initialMaxStreamDataBidiRemoteParameterID,
|
||||||
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
|
initialMaxStreamDataUniParameterID,
|
||||||
|
initialMaxDataParameterID,
|
||||||
tags := map[Tag][]byte{
|
initialMaxStreamsBidiParameterID,
|
||||||
TagICSL: icsl.Bytes(),
|
initialMaxStreamsUniParameterID,
|
||||||
TagMIDS: mids.Bytes(),
|
idleTimeoutParameterID,
|
||||||
TagCFCW: cfcw.Bytes(),
|
maxPacketSizeParameterID:
|
||||||
TagSFCW: sfcw.Bytes(),
|
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if p.OmitConnectionID {
|
default:
|
||||||
tags[TagTCID] = []byte{0, 0, 0, 0}
|
if r.Len() < int(paramLen) {
|
||||||
|
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
|
||||||
}
|
}
|
||||||
return tags
|
switch paramID {
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TransportParameters) unmarshal(data []byte) error {
|
|
||||||
var foundIdleTimeout bool
|
|
||||||
|
|
||||||
for len(data) >= 4 {
|
|
||||||
paramID := binary.BigEndian.Uint16(data[:2])
|
|
||||||
paramLen := int(binary.BigEndian.Uint16(data[2:4]))
|
|
||||||
data = data[4:]
|
|
||||||
if len(data) < paramLen {
|
|
||||||
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen)
|
|
||||||
}
|
|
||||||
switch transportParameterID(paramID) {
|
|
||||||
case initialMaxStreamDataParameterID:
|
|
||||||
if paramLen != 4 {
|
|
||||||
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen)
|
|
||||||
}
|
|
||||||
p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
|
|
||||||
case initialMaxDataParameterID:
|
|
||||||
if paramLen != 4 {
|
|
||||||
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen)
|
|
||||||
}
|
|
||||||
p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
|
|
||||||
case initialMaxBidiStreamsParameterID:
|
|
||||||
if paramLen != 2 {
|
|
||||||
return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen)
|
|
||||||
}
|
|
||||||
p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2])
|
|
||||||
case initialMaxUniStreamsParameterID:
|
|
||||||
if paramLen != 2 {
|
|
||||||
return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen)
|
|
||||||
}
|
|
||||||
p.MaxUniStreams = binary.BigEndian.Uint16(data[:2])
|
|
||||||
case idleTimeoutParameterID:
|
|
||||||
foundIdleTimeout = true
|
|
||||||
if paramLen != 2 {
|
|
||||||
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen)
|
|
||||||
}
|
|
||||||
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second)
|
|
||||||
case maxPacketSizeParameterID:
|
|
||||||
if paramLen != 2 {
|
|
||||||
return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen)
|
|
||||||
}
|
|
||||||
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2]))
|
|
||||||
if maxPacketSize < 1200 {
|
|
||||||
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
|
|
||||||
}
|
|
||||||
p.MaxPacketSize = maxPacketSize
|
|
||||||
case disableMigrationParameterID:
|
case disableMigrationParameterID:
|
||||||
if paramLen != 0 {
|
if paramLen != 0 {
|
||||||
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
|
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
|
||||||
}
|
}
|
||||||
p.DisableMigration = true
|
p.DisableMigration = true
|
||||||
case statelessResetTokenParameterID:
|
case statelessResetTokenParameterID:
|
||||||
|
if sentBy == protocol.PerspectiveClient {
|
||||||
|
return errors.New("client sent a stateless_reset_token")
|
||||||
|
}
|
||||||
if paramLen != 16 {
|
if paramLen != 16 {
|
||||||
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
|
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
|
||||||
}
|
}
|
||||||
p.StatelessResetToken = data[:16]
|
b := make([]byte, 16)
|
||||||
|
r.Read(b)
|
||||||
|
p.StatelessResetToken = b
|
||||||
|
case originalConnectionIDParameterID:
|
||||||
|
if sentBy == protocol.PerspectiveClient {
|
||||||
|
return errors.New("client sent an original_connection_id")
|
||||||
|
}
|
||||||
|
p.OriginalConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
|
||||||
|
default:
|
||||||
|
r.Seek(int64(paramLen), io.SeekCurrent)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
data = data[paramLen:]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(data) != 0 {
|
// check that every transport parameter was sent at most once
|
||||||
return fmt.Errorf("should have read all data. Still have %d bytes", len(data))
|
sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] })
|
||||||
|
for i := 0; i < len(parameterIDs)-1; i++ {
|
||||||
|
if parameterIDs[i] == parameterIDs[i+1] {
|
||||||
|
return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i])
|
||||||
}
|
}
|
||||||
if !foundIdleTimeout {
|
}
|
||||||
return errors.New("missing parameter")
|
|
||||||
|
if r.Len() != 0 {
|
||||||
|
return fmt.Errorf("should have read all data. Still have %d bytes", r.Len())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TransportParameters) readNumericTransportParameter(
|
||||||
|
r *bytes.Reader,
|
||||||
|
paramID transportParameterID,
|
||||||
|
expectedLen int,
|
||||||
|
) error {
|
||||||
|
remainingLen := r.Len()
|
||||||
|
val, err := utils.ReadVarInt(r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
|
||||||
|
}
|
||||||
|
if remainingLen-r.Len() != expectedLen {
|
||||||
|
return fmt.Errorf("inconsistent transport parameter length for %d", paramID)
|
||||||
|
}
|
||||||
|
switch paramID {
|
||||||
|
case initialMaxStreamDataBidiLocalParameterID:
|
||||||
|
p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val)
|
||||||
|
case initialMaxStreamDataBidiRemoteParameterID:
|
||||||
|
p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val)
|
||||||
|
case initialMaxStreamDataUniParameterID:
|
||||||
|
p.InitialMaxStreamDataUni = protocol.ByteCount(val)
|
||||||
|
case initialMaxDataParameterID:
|
||||||
|
p.InitialMaxData = protocol.ByteCount(val)
|
||||||
|
case initialMaxStreamsBidiParameterID:
|
||||||
|
p.MaxBidiStreams = val
|
||||||
|
case initialMaxStreamsUniParameterID:
|
||||||
|
p.MaxUniStreams = val
|
||||||
|
case idleTimeoutParameterID:
|
||||||
|
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Second)
|
||||||
|
case maxPacketSizeParameterID:
|
||||||
|
if val < 1200 {
|
||||||
|
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
|
||||||
|
}
|
||||||
|
p.MaxPacketSize = protocol.ByteCount(val)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TransportParameters) marshal(b *bytes.Buffer) {
|
func (p *TransportParameters) marshal(b *bytes.Buffer) {
|
||||||
// initial_max_stream_data
|
// initial_max_stream_data_bidi_local
|
||||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID))
|
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiLocalParameterID))
|
||||||
utils.BigEndian.WriteUint16(b, 4)
|
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiLocal))))
|
||||||
utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow))
|
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiLocal))
|
||||||
|
// initial_max_stream_data_bidi_remote
|
||||||
|
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiRemoteParameterID))
|
||||||
|
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiRemote))))
|
||||||
|
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiRemote))
|
||||||
|
// initial_max_stream_data_uni
|
||||||
|
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataUniParameterID))
|
||||||
|
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataUni))))
|
||||||
|
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataUni))
|
||||||
// initial_max_data
|
// initial_max_data
|
||||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
|
utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
|
||||||
utils.BigEndian.WriteUint16(b, 4)
|
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxData))))
|
||||||
utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow))
|
utils.WriteVarInt(b, uint64(p.InitialMaxData))
|
||||||
// initial_max_bidi_streams
|
// initial_max_bidi_streams
|
||||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID))
|
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsBidiParameterID))
|
||||||
utils.BigEndian.WriteUint16(b, 2)
|
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxBidiStreams)))
|
||||||
utils.BigEndian.WriteUint16(b, p.MaxBidiStreams)
|
utils.WriteVarInt(b, p.MaxBidiStreams)
|
||||||
// initial_max_uni_streams
|
// initial_max_uni_streams
|
||||||
utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID))
|
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsUniParameterID))
|
||||||
utils.BigEndian.WriteUint16(b, 2)
|
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxUniStreams)))
|
||||||
utils.BigEndian.WriteUint16(b, p.MaxUniStreams)
|
utils.WriteVarInt(b, p.MaxUniStreams)
|
||||||
// idle_timeout
|
// idle_timeout
|
||||||
utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
|
utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
|
||||||
utils.BigEndian.WriteUint16(b, 2)
|
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.IdleTimeout/time.Second))))
|
||||||
utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second))
|
utils.WriteVarInt(b, uint64(p.IdleTimeout/time.Second))
|
||||||
// max_packet_size
|
// max_packet_size
|
||||||
utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
|
utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
|
||||||
utils.BigEndian.WriteUint16(b, 2)
|
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(protocol.MaxReceivePacketSize))))
|
||||||
utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize))
|
utils.WriteVarInt(b, uint64(protocol.MaxReceivePacketSize))
|
||||||
// disable_migration
|
// disable_migration
|
||||||
if p.DisableMigration {
|
if p.DisableMigration {
|
||||||
utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
|
utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
|
||||||
|
@ -200,10 +196,15 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) {
|
||||||
utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
|
utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
|
||||||
b.Write(p.StatelessResetToken)
|
b.Write(p.StatelessResetToken)
|
||||||
}
|
}
|
||||||
|
// original_connection_id
|
||||||
|
if p.OriginalConnectionID.Len() > 0 {
|
||||||
|
utils.BigEndian.WriteUint16(b, uint16(originalConnectionIDParameterID))
|
||||||
|
utils.BigEndian.WriteUint16(b, uint16(p.OriginalConnectionID.Len()))
|
||||||
|
b.Write(p.OriginalConnectionID.Bytes())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns a string representation, intended for logging.
|
// String returns a string representation, intended for logging.
|
||||||
// It should only used for IETF QUIC.
|
|
||||||
func (p *TransportParameters) String() string {
|
func (p *TransportParameters) String() string {
|
||||||
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
|
return fmt.Sprintf("&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,30 +86,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetLowestPacketNotConfirmedAcked()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLowestPacketNotConfirmedAcked))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPacketNumberLen mocks base method
|
|
||||||
func (m *MockSentPacketHandler) GetPacketNumberLen(arg0 protocol.PacketNumber) protocol.PacketNumberLen {
|
|
||||||
ret := m.ctrl.Call(m, "GetPacketNumberLen", arg0)
|
|
||||||
ret0, _ := ret[0].(protocol.PacketNumberLen)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPacketNumberLen indicates an expected call of GetPacketNumberLen
|
|
||||||
func (mr *MockSentPacketHandlerMockRecorder) GetPacketNumberLen(arg0 interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPacketNumberLen", reflect.TypeOf((*MockSentPacketHandler)(nil).GetPacketNumberLen), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStopWaitingFrame mocks base method
|
|
||||||
func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame {
|
|
||||||
ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0)
|
|
||||||
ret0, _ := ret[0].(*wire.StopWaitingFrame)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame
|
|
||||||
func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockSentPacketHandler)(nil).GetStopWaitingFrame), arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnAlarm mocks base method
|
// OnAlarm mocks base method
|
||||||
func (m *MockSentPacketHandler) OnAlarm() error {
|
func (m *MockSentPacketHandler) OnAlarm() error {
|
||||||
ret := m.ctrl.Call(m, "OnAlarm")
|
ret := m.ctrl.Call(m, "OnAlarm")
|
||||||
|
@ -122,6 +98,31 @@ func (mr *MockSentPacketHandlerMockRecorder) OnAlarm() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAlarm", reflect.TypeOf((*MockSentPacketHandler)(nil).OnAlarm))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAlarm", reflect.TypeOf((*MockSentPacketHandler)(nil).OnAlarm))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeekPacketNumber mocks base method
|
||||||
|
func (m *MockSentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||||
|
ret := m.ctrl.Call(m, "PeekPacketNumber")
|
||||||
|
ret0, _ := ret[0].(protocol.PacketNumber)
|
||||||
|
ret1, _ := ret[1].(protocol.PacketNumberLen)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeekPacketNumber indicates an expected call of PeekPacketNumber
|
||||||
|
func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber() *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopPacketNumber mocks base method
|
||||||
|
func (m *MockSentPacketHandler) PopPacketNumber() protocol.PacketNumber {
|
||||||
|
ret := m.ctrl.Call(m, "PopPacketNumber")
|
||||||
|
ret0, _ := ret[0].(protocol.PacketNumber)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopPacketNumber indicates an expected call of PopPacketNumber
|
||||||
|
func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber() *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber))
|
||||||
|
}
|
||||||
|
|
||||||
// ReceivedAck mocks base method
|
// ReceivedAck mocks base method
|
||||||
func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.PacketNumber, arg2 protocol.EncryptionLevel, arg3 time.Time) error {
|
func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.PacketNumber, arg2 protocol.EncryptionLevel, arg3 time.Time) error {
|
||||||
ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2, arg3)
|
ret := m.ctrl.Call(m, "ReceivedAck", arg0, arg1, arg2, arg3)
|
||||||
|
|
149
vendor/github.com/lucas-clemente/quic-go/internal/mocks/crypto_setup.go
generated
vendored
Normal file
149
vendor/github.com/lucas-clemente/quic-go/internal/mocks/crypto_setup.go
generated
vendored
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: CryptoSetup)
|
||||||
|
|
||||||
|
// Package mocks is a generated GoMock package.
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCryptoSetup is a mock of CryptoSetup interface
|
||||||
|
type MockCryptoSetup struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockCryptoSetupMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup
|
||||||
|
type MockCryptoSetupMockRecorder struct {
|
||||||
|
mock *MockCryptoSetup
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockCryptoSetup creates a new mock instance
|
||||||
|
func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup {
|
||||||
|
mock := &MockCryptoSetup{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockCryptoSetupMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close mocks base method
|
||||||
|
func (m *MockCryptoSetup) Close() error {
|
||||||
|
ret := m.ctrl.Call(m, "Close")
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close indicates an expected call of Close
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionState mocks base method
|
||||||
|
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState {
|
||||||
|
ret := m.ctrl.Call(m, "ConnectionState")
|
||||||
|
ret0, _ := ret[0].(handshake.ConnectionState)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionState indicates an expected call of ConnectionState
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSealer mocks base method
|
||||||
|
func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
|
||||||
|
ret := m.ctrl.Call(m, "GetSealer")
|
||||||
|
ret0, _ := ret[0].(protocol.EncryptionLevel)
|
||||||
|
ret1, _ := ret[1].(handshake.Sealer)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSealer indicates an expected call of GetSealer
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) GetSealer() *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealer))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSealerWithEncryptionLevel mocks base method
|
||||||
|
func (m *MockCryptoSetup) GetSealerWithEncryptionLevel(arg0 protocol.EncryptionLevel) (handshake.Sealer, error) {
|
||||||
|
ret := m.ctrl.Call(m, "GetSealerWithEncryptionLevel", arg0)
|
||||||
|
ret0, _ := ret[0].(handshake.Sealer)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSealerWithEncryptionLevel indicates an expected call of GetSealerWithEncryptionLevel
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) GetSealerWithEncryptionLevel(arg0 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerWithEncryptionLevel", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealerWithEncryptionLevel), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleMessage mocks base method
|
||||||
|
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
|
||||||
|
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleMessage indicates an expected call of HandleMessage
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open1RTT mocks base method
|
||||||
|
func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||||
|
ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open1RTT indicates an expected call of Open1RTT
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenHandshake mocks base method
|
||||||
|
func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||||
|
ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenHandshake indicates an expected call of OpenHandshake
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenInitial mocks base method
|
||||||
|
func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||||
|
ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenInitial indicates an expected call of OpenInitial
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunHandshake mocks base method
|
||||||
|
func (m *MockCryptoSetup) RunHandshake() error {
|
||||||
|
ret := m.ctrl.Call(m, "RunHandshake")
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunHandshake indicates an expected call of RunHandshake
|
||||||
|
func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
//go:generate sh -c "../mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
|
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
|
||||||
|
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
|
||||||
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
|
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
|
||||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
|
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Sealer)
|
||||||
|
|
||||||
|
// Package mocks is a generated GoMock package.
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockSealer is a mock of Sealer interface
|
||||||
|
type MockSealer struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockSealerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockSealerMockRecorder is the mock recorder for MockSealer
|
||||||
|
type MockSealerMockRecorder struct {
|
||||||
|
mock *MockSealer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockSealer creates a new mock instance
|
||||||
|
func NewMockSealer(ctrl *gomock.Controller) *MockSealer {
|
||||||
|
mock := &MockSealer{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockSealerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockSealer) EXPECT() *MockSealerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overhead mocks base method
|
||||||
|
func (m *MockSealer) Overhead() int {
|
||||||
|
ret := m.ctrl.Call(m, "Overhead")
|
||||||
|
ret0, _ := ret[0].(int)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overhead indicates an expected call of Overhead
|
||||||
|
func (mr *MockSealerMockRecorder) Overhead() *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockSealer)(nil).Overhead))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seal mocks base method
|
||||||
|
func (m *MockSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte {
|
||||||
|
ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3)
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seal indicates an expected call of Seal
|
||||||
|
func (mr *MockSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockSealer)(nil).Seal), arg0, arg1, arg2, arg3)
|
||||||
|
}
|
72
vendor/github.com/lucas-clemente/quic-go/internal/mocks/tls_extension_handler.go
generated
vendored
72
vendor/github.com/lucas-clemente/quic-go/internal/mocks/tls_extension_handler.go
generated
vendored
|
@ -1,72 +0,0 @@
|
||||||
// Code generated by MockGen. DO NOT EDIT.
|
|
||||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler)
|
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
|
||||||
package mocks
|
|
||||||
|
|
||||||
import (
|
|
||||||
reflect "reflect"
|
|
||||||
|
|
||||||
mint "github.com/bifurcation/mint"
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
|
||||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface
|
|
||||||
type MockTLSExtensionHandler struct {
|
|
||||||
ctrl *gomock.Controller
|
|
||||||
recorder *MockTLSExtensionHandlerMockRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler
|
|
||||||
type MockTLSExtensionHandlerMockRecorder struct {
|
|
||||||
mock *MockTLSExtensionHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockTLSExtensionHandler creates a new mock instance
|
|
||||||
func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler {
|
|
||||||
mock := &MockTLSExtensionHandler{ctrl: ctrl}
|
|
||||||
mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock}
|
|
||||||
return mock
|
|
||||||
}
|
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use
|
|
||||||
func (m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder {
|
|
||||||
return m.recorder
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerParams mocks base method
|
|
||||||
func (m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters {
|
|
||||||
ret := m.ctrl.Call(m, "GetPeerParams")
|
|
||||||
ret0, _ := ret[0].(<-chan handshake.TransportParameters)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerParams indicates an expected call of GetPeerParams
|
|
||||||
func (mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Receive mocks base method
|
|
||||||
func (m *MockTLSExtensionHandler) Receive(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error {
|
|
||||||
ret := m.ctrl.Call(m, "Receive", arg0, arg1)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Receive indicates an expected call of Receive
|
|
||||||
func (mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send mocks base method
|
|
||||||
func (m *MockTLSExtensionHandler) Send(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error {
|
|
||||||
ret := m.ctrl.Call(m, "Send", arg0, arg1)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send indicates an expected call of Send
|
|
||||||
func (mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call {
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1)
|
|
||||||
}
|
|
24
vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go
generated
vendored
24
vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go
generated
vendored
|
@ -7,22 +7,22 @@ type EncryptionLevel int
|
||||||
const (
|
const (
|
||||||
// EncryptionUnspecified is a not specified encryption level
|
// EncryptionUnspecified is a not specified encryption level
|
||||||
EncryptionUnspecified EncryptionLevel = iota
|
EncryptionUnspecified EncryptionLevel = iota
|
||||||
// EncryptionUnencrypted is not encrypted
|
// EncryptionInitial is the Initial encryption level
|
||||||
EncryptionUnencrypted
|
EncryptionInitial
|
||||||
// EncryptionSecure is encrypted, but not forward secure
|
// EncryptionHandshake is the Handshake encryption level
|
||||||
EncryptionSecure
|
EncryptionHandshake
|
||||||
// EncryptionForwardSecure is forward secure
|
// Encryption1RTT is the 1-RTT encryption level
|
||||||
EncryptionForwardSecure
|
Encryption1RTT
|
||||||
)
|
)
|
||||||
|
|
||||||
func (e EncryptionLevel) String() string {
|
func (e EncryptionLevel) String() string {
|
||||||
switch e {
|
switch e {
|
||||||
case EncryptionUnencrypted:
|
case EncryptionInitial:
|
||||||
return "unencrypted"
|
return "Initial"
|
||||||
case EncryptionSecure:
|
case EncryptionHandshake:
|
||||||
return "encrypted (not forward-secure)"
|
return "Handshake"
|
||||||
case EncryptionForwardSecure:
|
case Encryption1RTT:
|
||||||
return "forward-secure"
|
return "1-RTT"
|
||||||
}
|
}
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,6 @@ func InferPacketNumber(
|
||||||
version VersionNumber,
|
version VersionNumber,
|
||||||
) PacketNumber {
|
) PacketNumber {
|
||||||
var epochDelta PacketNumber
|
var epochDelta PacketNumber
|
||||||
if version.UsesVarintPacketNumbers() {
|
|
||||||
switch packetNumberLength {
|
switch packetNumberLength {
|
||||||
case PacketNumberLen1:
|
case PacketNumberLen1:
|
||||||
epochDelta = PacketNumber(1) << 7
|
epochDelta = PacketNumber(1) << 7
|
||||||
|
@ -17,9 +16,6 @@ func InferPacketNumber(
|
||||||
case PacketNumberLen4:
|
case PacketNumberLen4:
|
||||||
epochDelta = PacketNumber(1) << 30
|
epochDelta = PacketNumber(1) << 30
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8)
|
|
||||||
}
|
|
||||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||||
prevEpochBegin := epoch - epochDelta
|
prevEpochBegin := epoch - epochDelta
|
||||||
nextEpochBegin := epoch + epochDelta
|
nextEpochBegin := epoch + epochDelta
|
||||||
|
@ -48,8 +44,7 @@ func delta(a, b PacketNumber) PacketNumber {
|
||||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
|
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
|
||||||
diff := uint64(packetNumber - leastUnacked)
|
diff := uint64(packetNumber - leastUnacked)
|
||||||
if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) ||
|
if diff < (1 << (14 - 1)) {
|
||||||
!version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) {
|
|
||||||
return PacketNumberLen2
|
return PacketNumberLen2
|
||||||
}
|
}
|
||||||
return PacketNumberLen4
|
return PacketNumberLen4
|
||||||
|
@ -63,8 +58,5 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
|
||||||
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
|
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
|
||||||
return PacketNumberLen2
|
return PacketNumberLen2
|
||||||
}
|
}
|
||||||
if packetNumber < (1 << (uint8(PacketNumberLen4) * 8)) {
|
|
||||||
return PacketNumberLen4
|
return PacketNumberLen4
|
||||||
}
|
|
||||||
return PacketNumberLen6
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,9 @@ const MaxPacketSizeIPv4 = 1252
|
||||||
// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
|
// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
|
||||||
const MaxPacketSizeIPv6 = 1232
|
const MaxPacketSizeIPv6 = 1232
|
||||||
|
|
||||||
|
// MinStatelessResetSize is the minimum size of a stateless reset packet
|
||||||
|
const MinStatelessResetSize = 1 + 20 + 16
|
||||||
|
|
||||||
// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
|
// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
|
||||||
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
|
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
|
||||||
const NonForwardSecurePacketSizeReduction = 50
|
const NonForwardSecurePacketSizeReduction = 50
|
||||||
|
@ -24,38 +27,22 @@ const InitialCongestionWindow ByteCount = 32 * DefaultTCPMSS
|
||||||
// session queues for later until it sends a public reset.
|
// session queues for later until it sends a public reset.
|
||||||
const MaxUndecryptablePackets = 10
|
const MaxUndecryptablePackets = 10
|
||||||
|
|
||||||
// PublicResetTimeout is the time to wait before sending a Public Reset when receiving too many undecryptable packets during the handshake
|
|
||||||
// This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto
|
|
||||||
const PublicResetTimeout = 500 * time.Millisecond
|
|
||||||
|
|
||||||
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
|
|
||||||
// This is the value that Google servers are using
|
|
||||||
const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB
|
|
||||||
|
|
||||||
// ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data
|
|
||||||
// This is the value that Google servers are using
|
|
||||||
const ReceiveConnectionFlowControlWindow = (1 << 10) * 48 // 48 kB
|
|
||||||
|
|
||||||
// DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server
|
|
||||||
// This is the value that Google servers are using
|
|
||||||
const DefaultMaxReceiveStreamFlowControlWindowServer = 1 * (1 << 20) // 1 MB
|
|
||||||
|
|
||||||
// DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server
|
|
||||||
// This is the value that Google servers are using
|
|
||||||
const DefaultMaxReceiveConnectionFlowControlWindowServer = 1.5 * (1 << 20) // 1.5 MB
|
|
||||||
|
|
||||||
// DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client
|
|
||||||
// This is the value that Chromium is using
|
|
||||||
const DefaultMaxReceiveStreamFlowControlWindowClient = 6 * (1 << 20) // 6 MB
|
|
||||||
|
|
||||||
// DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client
|
|
||||||
// This is the value that Google servers are using
|
|
||||||
const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15 MB
|
|
||||||
|
|
||||||
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
|
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
|
||||||
// This is the value that Chromium is using
|
// This is the value that Chromium is using
|
||||||
const ConnectionFlowControlMultiplier = 1.5
|
const ConnectionFlowControlMultiplier = 1.5
|
||||||
|
|
||||||
|
// InitialMaxStreamData is the stream-level flow control window for receiving data
|
||||||
|
const InitialMaxStreamData = (1 << 10) * 512 // 512 kb
|
||||||
|
|
||||||
|
// InitialMaxData is the connection-level flow control window for receiving data
|
||||||
|
const InitialMaxData = ConnectionFlowControlMultiplier * InitialMaxStreamData
|
||||||
|
|
||||||
|
// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data, for the server
|
||||||
|
const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB
|
||||||
|
|
||||||
|
// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data, for the server
|
||||||
|
const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 12 MB
|
||||||
|
|
||||||
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
|
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
|
||||||
const WindowUpdateThreshold = 0.25
|
const WindowUpdateThreshold = 0.25
|
||||||
|
|
||||||
|
@ -65,12 +52,6 @@ const DefaultMaxIncomingStreams = 100
|
||||||
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
|
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
|
||||||
const DefaultMaxIncomingUniStreams = 100
|
const DefaultMaxIncomingUniStreams = 100
|
||||||
|
|
||||||
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
|
|
||||||
const MaxStreamsMultiplier = 1.1
|
|
||||||
|
|
||||||
// MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used.
|
|
||||||
const MaxStreamsMinimumIncrement = 10
|
|
||||||
|
|
||||||
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
|
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
|
||||||
const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets
|
const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets
|
||||||
|
|
||||||
|
@ -103,15 +84,9 @@ const MaxNonRetransmittableAcks = 19
|
||||||
// prevents DoS attacks against the streamFrameSorter
|
// prevents DoS attacks against the streamFrameSorter
|
||||||
const MaxStreamFrameSorterGaps = 1000
|
const MaxStreamFrameSorterGaps = 1000
|
||||||
|
|
||||||
// CryptoMaxParams is the upper limit for the number of parameters in a crypto message.
|
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
|
||||||
// Value taken from Chrome.
|
// This limits the size of the ClientHello and Certificates that can be received.
|
||||||
const CryptoMaxParams = 128
|
const MaxCryptoStreamOffset = 16 * (1 << 10)
|
||||||
|
|
||||||
// CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message.
|
|
||||||
const CryptoParameterMaxLength = 4000
|
|
||||||
|
|
||||||
// EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX.
|
|
||||||
const EphermalKeyLifetime = time.Minute
|
|
||||||
|
|
||||||
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
|
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
|
||||||
const MinRemoteIdleTimeout = 5 * time.Second
|
const MinRemoteIdleTimeout = 5 * time.Second
|
||||||
|
@ -122,12 +97,9 @@ const DefaultIdleTimeout = 30 * time.Second
|
||||||
// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
|
// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
|
||||||
const DefaultHandshakeTimeout = 10 * time.Second
|
const DefaultHandshakeTimeout = 10 * time.Second
|
||||||
|
|
||||||
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
|
// RetiredConnectionIDDeleteTimeout is the time we keep closed sessions around in order to retransmit the CONNECTION_CLOSE.
|
||||||
// after this time all information about the old connection will be deleted
|
// after this time all information about the old connection will be deleted
|
||||||
const ClosedSessionDeleteTimeout = time.Minute
|
const RetiredConnectionIDDeleteTimeout = 5 * time.Second
|
||||||
|
|
||||||
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
|
|
||||||
const NumCachedCertificates = 128
|
|
||||||
|
|
||||||
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
|
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
|
||||||
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
|
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
|
||||||
|
@ -135,7 +107,7 @@ const NumCachedCertificates = 128
|
||||||
// 2. it reduces the head-of-line blocking, when a packet is lost
|
// 2. it reduces the head-of-line blocking, when a packet is lost
|
||||||
const MinStreamFrameSize ByteCount = 128
|
const MinStreamFrameSize ByteCount = 128
|
||||||
|
|
||||||
// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write
|
// MaxAckFrameSize is the maximum size for an ACK frame that we write
|
||||||
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
||||||
// The MaxAckFrameSize should be large enough to encode many ACK range,
|
// The MaxAckFrameSize should be large enough to encode many ACK range,
|
||||||
// but must ensure that a maximum size ACK frame fits into one packet.
|
// but must ensure that a maximum size ACK frame fits into one packet.
|
||||||
|
@ -149,6 +121,3 @@ const MinPacingDelay time.Duration = 100 * time.Microsecond
|
||||||
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
|
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
|
||||||
// if no other value is configured.
|
// if no other value is configured.
|
||||||
const DefaultConnectionIDLength = 4
|
const DefaultConnectionIDLength = 4
|
||||||
|
|
||||||
// MaxRetries is the maximum number of Retries a client will do before failing the connection.
|
|
||||||
const MaxRetries = 3
|
|
|
@ -19,11 +19,9 @@ const (
|
||||||
PacketNumberLen2 PacketNumberLen = 2
|
PacketNumberLen2 PacketNumberLen = 2
|
||||||
// PacketNumberLen4 is a packet number length of 4 bytes
|
// PacketNumberLen4 is a packet number length of 4 bytes
|
||||||
PacketNumberLen4 PacketNumberLen = 4
|
PacketNumberLen4 PacketNumberLen = 4
|
||||||
// PacketNumberLen6 is a packet number length of 6 bytes
|
|
||||||
PacketNumberLen6 PacketNumberLen = 6
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// The PacketType is the Long Header Type (only used for the IETF draft header format)
|
// The PacketType is the Long Header Type
|
||||||
type PacketType uint8
|
type PacketType uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -71,10 +69,7 @@ const MaxReceivePacketSize ByteCount = 1452 - 64
|
||||||
// Used in QUIC for congestion window computations in bytes.
|
// Used in QUIC for congestion window computations in bytes.
|
||||||
const DefaultTCPMSS ByteCount = 1460
|
const DefaultTCPMSS ByteCount = 1460
|
||||||
|
|
||||||
// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC)
|
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
|
||||||
const MinClientHelloSize = 1024
|
|
||||||
|
|
||||||
// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is required to have.
|
|
||||||
const MinInitialPacketSize = 1200
|
const MinInitialPacketSize = 1200
|
||||||
|
|
||||||
// MaxClientHellos is the maximum number of times we'll send a client hello
|
// MaxClientHellos is the maximum number of times we'll send a client hello
|
||||||
|
@ -83,8 +78,5 @@ const MinInitialPacketSize = 1200
|
||||||
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
|
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
|
||||||
const MaxClientHellos = 3
|
const MaxClientHellos = 3
|
||||||
|
|
||||||
// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets.
|
|
||||||
const ConnectionIDLenGQUIC = 8
|
|
||||||
|
|
||||||
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
||||||
const MinConnectionIDLenInitial = 8
|
const MinConnectionIDLenInitial = 8
|
||||||
|
|
|
@ -3,34 +3,65 @@ package protocol
|
||||||
// A StreamID in QUIC
|
// A StreamID in QUIC
|
||||||
type StreamID uint64
|
type StreamID uint64
|
||||||
|
|
||||||
// MaxBidiStreamID is the highest stream ID that the peer is allowed to open,
|
// StreamType encodes if this is a unidirectional or bidirectional stream
|
||||||
// when it is allowed to open numStreams bidirectional streams.
|
type StreamType uint8
|
||||||
// It is only valid for IETF QUIC.
|
|
||||||
func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
|
const (
|
||||||
|
// StreamTypeUni is a unidirectional stream
|
||||||
|
StreamTypeUni StreamType = iota
|
||||||
|
// StreamTypeBidi is a bidirectional stream
|
||||||
|
StreamTypeBidi
|
||||||
|
)
|
||||||
|
|
||||||
|
// InitiatedBy says if the stream was initiated by the client or by the server
|
||||||
|
func (s StreamID) InitiatedBy() Perspective {
|
||||||
|
if s%2 == 0 {
|
||||||
|
return PerspectiveClient
|
||||||
|
}
|
||||||
|
return PerspectiveServer
|
||||||
|
}
|
||||||
|
|
||||||
|
//Type says if this is a unidirectional or bidirectional stream
|
||||||
|
func (s StreamID) Type() StreamType {
|
||||||
|
if s%4 >= 2 {
|
||||||
|
return StreamTypeUni
|
||||||
|
}
|
||||||
|
return StreamTypeBidi
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamNum returns how many streams in total are below this
|
||||||
|
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
|
||||||
|
func (s StreamID) StreamNum() uint64 {
|
||||||
|
return uint64(s/4) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxStreamID is the highest stream ID that a peer is allowed to open,
|
||||||
|
// when it is allowed to open numStreams.
|
||||||
|
func MaxStreamID(stype StreamType, numStreams uint64, pers Perspective) StreamID {
|
||||||
if numStreams == 0 {
|
if numStreams == 0 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
var first StreamID
|
var first StreamID
|
||||||
if pers == PerspectiveClient {
|
switch stype {
|
||||||
|
case StreamTypeBidi:
|
||||||
|
switch pers {
|
||||||
|
case PerspectiveClient:
|
||||||
|
first = 0
|
||||||
|
case PerspectiveServer:
|
||||||
first = 1
|
first = 1
|
||||||
} else {
|
}
|
||||||
first = 4
|
case StreamTypeUni:
|
||||||
|
switch pers {
|
||||||
|
case PerspectiveClient:
|
||||||
|
first = 2
|
||||||
|
case PerspectiveServer:
|
||||||
|
first = 3
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return first + 4*StreamID(numStreams-1)
|
return first + 4*StreamID(numStreams-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaxUniStreamID is the highest stream ID that the peer is allowed to open,
|
// FirstStream returns the first valid stream ID
|
||||||
// when it is allowed to open numStreams unidirectional streams.
|
func FirstStream(stype StreamType, pers Perspective) StreamID {
|
||||||
// It is only valid for IETF QUIC.
|
return MaxStreamID(stype, 1, pers)
|
||||||
func MaxUniStreamID(numStreams int, pers Perspective) StreamID {
|
|
||||||
if numStreams == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
var first StreamID
|
|
||||||
if pers == PerspectiveClient {
|
|
||||||
first = 3
|
|
||||||
} else {
|
|
||||||
first = 2
|
|
||||||
}
|
|
||||||
return first + 4*StreamID(numStreams-1)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,32 +18,18 @@ const (
|
||||||
|
|
||||||
// The version numbers, making grepping easier
|
// The version numbers, making grepping easier
|
||||||
const (
|
const (
|
||||||
Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9
|
|
||||||
Version43 VersionNumber = gquicVersion0 + 4*0x100 + 0x3
|
|
||||||
Version44 VersionNumber = gquicVersion0 + 4*0x100 + 0x4
|
|
||||||
VersionTLS VersionNumber = 101
|
VersionTLS VersionNumber = 101
|
||||||
VersionWhatever VersionNumber = 0 // for when the version doesn't matter
|
VersionWhatever VersionNumber = 1 // for when the version doesn't matter
|
||||||
VersionUnknown VersionNumber = math.MaxUint32
|
VersionUnknown VersionNumber = math.MaxUint32
|
||||||
|
|
||||||
VersionMilestone0_10_0 VersionNumber = 0x51474f02
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SupportedVersions lists the versions that the server supports
|
// SupportedVersions lists the versions that the server supports
|
||||||
// must be in sorted descending order
|
// must be in sorted descending order
|
||||||
var SupportedVersions = []VersionNumber{
|
var SupportedVersions = []VersionNumber{VersionTLS}
|
||||||
Version44,
|
|
||||||
Version43,
|
|
||||||
Version39,
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValidVersion says if the version is known to quic-go
|
// IsValidVersion says if the version is known to quic-go
|
||||||
func IsValidVersion(v VersionNumber) bool {
|
func IsValidVersion(v VersionNumber) bool {
|
||||||
return v == VersionTLS || v == VersionMilestone0_10_0 || IsSupportedVersion(SupportedVersions, v)
|
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
|
||||||
}
|
|
||||||
|
|
||||||
// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake
|
|
||||||
func (vn VersionNumber) UsesTLS() bool {
|
|
||||||
return !vn.isGQUIC()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vn VersionNumber) String() string {
|
func (vn VersionNumber) String() string {
|
||||||
|
@ -52,8 +38,6 @@ func (vn VersionNumber) String() string {
|
||||||
return "whatever"
|
return "whatever"
|
||||||
case VersionUnknown:
|
case VersionUnknown:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
case VersionMilestone0_10_0:
|
|
||||||
return "quic-go Milestone 0.10.0"
|
|
||||||
case VersionTLS:
|
case VersionTLS:
|
||||||
return "TLS dev version (WIP)"
|
return "TLS dev version (WIP)"
|
||||||
default:
|
default:
|
||||||
|
@ -66,61 +50,9 @@ func (vn VersionNumber) String() string {
|
||||||
|
|
||||||
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
|
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
|
||||||
func (vn VersionNumber) ToAltSvc() string {
|
func (vn VersionNumber) ToAltSvc() string {
|
||||||
if vn.isGQUIC() {
|
|
||||||
return fmt.Sprintf("%d", vn.toGQUICVersion())
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d", vn)
|
return fmt.Sprintf("%d", vn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CryptoStreamID gets the Stream ID of the crypto stream
|
|
||||||
func (vn VersionNumber) CryptoStreamID() StreamID {
|
|
||||||
if vn.isGQUIC() {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsesIETFFrameFormat tells if this version uses the IETF frame format
|
|
||||||
func (vn VersionNumber) UsesIETFFrameFormat() bool {
|
|
||||||
return !vn.isGQUIC()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsesIETFHeaderFormat tells if this version uses the IETF header format
|
|
||||||
func (vn VersionNumber) UsesIETFHeaderFormat() bool {
|
|
||||||
return !vn.isGQUIC() || vn >= Version44
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsesLengthInHeader tells if this version uses the Length field in the IETF header
|
|
||||||
func (vn VersionNumber) UsesLengthInHeader() bool {
|
|
||||||
return !vn.isGQUIC()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsesTokenInHeader tells if this version uses the Token field in the IETF header
|
|
||||||
func (vn VersionNumber) UsesTokenInHeader() bool {
|
|
||||||
return !vn.isGQUIC()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsesStopWaitingFrames tells if this version uses STOP_WAITING frames
|
|
||||||
func (vn VersionNumber) UsesStopWaitingFrames() bool {
|
|
||||||
return vn.isGQUIC() && vn <= Version43
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsesVarintPacketNumbers tells if this version uses 7/14/30 bit packet numbers
|
|
||||||
func (vn VersionNumber) UsesVarintPacketNumbers() bool {
|
|
||||||
return !vn.isGQUIC()
|
|
||||||
}
|
|
||||||
|
|
||||||
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
|
|
||||||
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
|
|
||||||
if id == vn.CryptoStreamID() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if vn.isGQUIC() && id == 3 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vn VersionNumber) isGQUIC() bool {
|
func (vn VersionNumber) isGQUIC() bool {
|
||||||
return vn > gquicVersion0 && vn <= maxGquicVersion
|
return vn > gquicVersion0 && vn <= maxGquicVersion
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrorCode can be used as a normal error without reason.
|
// ErrorCode can be used as a normal error without reason.
|
||||||
type ErrorCode uint32
|
type ErrorCode uint16
|
||||||
|
|
||||||
func (e ErrorCode) Error() string {
|
func (e ErrorCode) Error() string {
|
||||||
return e.String()
|
return e.String()
|
|
@ -13,13 +13,6 @@ type ByteOrder interface {
|
||||||
ReadUint16(io.ByteReader) (uint16, error)
|
ReadUint16(io.ByteReader) (uint16, error)
|
||||||
|
|
||||||
WriteUint64(*bytes.Buffer, uint64)
|
WriteUint64(*bytes.Buffer, uint64)
|
||||||
WriteUint56(*bytes.Buffer, uint64)
|
|
||||||
WriteUint48(*bytes.Buffer, uint64)
|
|
||||||
WriteUint40(*bytes.Buffer, uint64)
|
|
||||||
WriteUint32(*bytes.Buffer, uint32)
|
WriteUint32(*bytes.Buffer, uint32)
|
||||||
WriteUint24(*bytes.Buffer, uint32)
|
|
||||||
WriteUint16(*bytes.Buffer, uint16)
|
WriteUint16(*bytes.Buffer, uint16)
|
||||||
|
|
||||||
ReadUfloat16(io.ByteReader) (uint64, error)
|
|
||||||
WriteUfloat16(*bytes.Buffer, uint64)
|
|
||||||
}
|
}
|
||||||
|
|
50
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
50
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
|
@ -2,7 +2,6 @@ package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -97,61 +96,12 @@ func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteUint56 writes 56 bit of a uint64
|
|
||||||
func (bigEndian) WriteUint56(b *bytes.Buffer, i uint64) {
|
|
||||||
if i >= (1 << 56) {
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
|
|
||||||
}
|
|
||||||
b.Write([]byte{
|
|
||||||
uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
|
|
||||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint48 writes 48 bit of a uint64
|
|
||||||
func (bigEndian) WriteUint48(b *bytes.Buffer, i uint64) {
|
|
||||||
if i >= (1 << 48) {
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
|
|
||||||
}
|
|
||||||
b.Write([]byte{
|
|
||||||
uint8(i >> 40), uint8(i >> 32),
|
|
||||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint40 writes 40 bit of a uint64
|
|
||||||
func (bigEndian) WriteUint40(b *bytes.Buffer, i uint64) {
|
|
||||||
if i >= (1 << 40) {
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
|
|
||||||
}
|
|
||||||
b.Write([]byte{
|
|
||||||
uint8(i >> 32),
|
|
||||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint32 writes a uint32
|
// WriteUint32 writes a uint32
|
||||||
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
||||||
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteUint24 writes 24 bit of a uint32
|
|
||||||
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
|
|
||||||
if i >= (1 << 24) {
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
|
|
||||||
}
|
|
||||||
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint16 writes a uint16
|
// WriteUint16 writes a uint16
|
||||||
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
||||||
b.Write([]byte{uint8(i >> 8), uint8(i)})
|
b.Write([]byte{uint8(i >> 8), uint8(i)})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l bigEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
|
|
||||||
return readUfloat16(b, l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l bigEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
|
|
||||||
writeUfloat16(b, l, val)
|
|
||||||
}
|
|
||||||
|
|
157
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go
generated
vendored
157
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_little_endian.go
generated
vendored
|
@ -1,157 +0,0 @@
|
||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LittleEndian is the little-endian implementation of ByteOrder.
|
|
||||||
var LittleEndian ByteOrder = littleEndian{}
|
|
||||||
|
|
||||||
type littleEndian struct{}
|
|
||||||
|
|
||||||
var _ ByteOrder = &littleEndian{}
|
|
||||||
|
|
||||||
// ReadUintN reads N bytes
|
|
||||||
func (littleEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
|
|
||||||
var res uint64
|
|
||||||
for i := uint8(0); i < length; i++ {
|
|
||||||
bt, err := b.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
res ^= uint64(bt) << (i * 8)
|
|
||||||
}
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint64 reads a uint64
|
|
||||||
func (littleEndian) ReadUint64(b io.ByteReader) (uint64, error) {
|
|
||||||
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
|
|
||||||
var err error
|
|
||||||
if b1, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b2, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b3, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b4, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b5, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b6, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b7, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b8, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint32 reads a uint32
|
|
||||||
func (littleEndian) ReadUint32(b io.ByteReader) (uint32, error) {
|
|
||||||
var b1, b2, b3, b4 uint8
|
|
||||||
var err error
|
|
||||||
if b1, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b2, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b3, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b4, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint16 reads a uint16
|
|
||||||
func (littleEndian) ReadUint16(b io.ByteReader) (uint16, error) {
|
|
||||||
var b1, b2 uint8
|
|
||||||
var err error
|
|
||||||
if b1, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if b2, err = b.ReadByte(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return uint16(b1) + uint16(b2)<<8, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint64 writes a uint64
|
|
||||||
func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) {
|
|
||||||
b.Write([]byte{
|
|
||||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
|
||||||
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), uint8(i >> 56),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint56 writes 56 bit of a uint64
|
|
||||||
func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) {
|
|
||||||
if i >= (1 << 56) {
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
|
|
||||||
}
|
|
||||||
b.Write([]byte{
|
|
||||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
|
||||||
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint48 writes 48 bit of a uint64
|
|
||||||
func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) {
|
|
||||||
if i >= (1 << 48) {
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
|
|
||||||
}
|
|
||||||
b.Write([]byte{
|
|
||||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
|
||||||
uint8(i >> 32), uint8(i >> 40),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint40 writes 40 bit of a uint64
|
|
||||||
func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) {
|
|
||||||
if i >= (1 << 40) {
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
|
|
||||||
}
|
|
||||||
b.Write([]byte{
|
|
||||||
uint8(i), uint8(i >> 8), uint8(i >> 16),
|
|
||||||
uint8(i >> 24), uint8(i >> 32),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint32 writes a uint32
|
|
||||||
func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
|
||||||
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24)})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint24 writes 24 bit of a uint32
|
|
||||||
func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) {
|
|
||||||
if i >= (1 << 24) {
|
|
||||||
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
|
|
||||||
}
|
|
||||||
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)})
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteUint16 writes a uint16
|
|
||||||
func (littleEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
|
||||||
b.Write([]byte{uint8(i), uint8(i >> 8)})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l littleEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
|
|
||||||
return readUfloat16(b, l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l littleEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
|
|
||||||
writeUfloat16(b, l, val)
|
|
||||||
}
|
|
|
@ -1,86 +0,0 @@
|
||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
)
|
|
||||||
|
|
||||||
// We define an unsigned 16-bit floating point value, inspired by IEEE floats
|
|
||||||
// (http://en.wikipedia.org/wiki/Half_precision_floating-point_format),
|
|
||||||
// with 5-bit exponent (bias 1), 11-bit mantissa (effective 12 with hidden
|
|
||||||
// bit) and denormals, but without signs, transfinites or fractions. Wire format
|
|
||||||
// 16 bits (little-endian byte order) are split into exponent (high 5) and
|
|
||||||
// mantissa (low 11) and decoded as:
|
|
||||||
// uint64_t value;
|
|
||||||
// if (exponent == 0) value = mantissa;
|
|
||||||
// else value = (mantissa | 1 << 11) << (exponent - 1)
|
|
||||||
const uFloat16ExponentBits = 5
|
|
||||||
const uFloat16MaxExponent = (1 << uFloat16ExponentBits) - 2 // 30
|
|
||||||
const uFloat16MantissaBits = 16 - uFloat16ExponentBits // 11
|
|
||||||
const uFloat16MantissaEffectiveBits = uFloat16MantissaBits + 1 // 12
|
|
||||||
const uFloat16MaxValue = ((uint64(1) << uFloat16MantissaEffectiveBits) - 1) << uFloat16MaxExponent // 0x3FFC0000000
|
|
||||||
|
|
||||||
// readUfloat16 reads a float in the QUIC-float16 format and returns its uint64 representation
|
|
||||||
func readUfloat16(b io.ByteReader, byteOrder ByteOrder) (uint64, error) {
|
|
||||||
val, err := byteOrder.ReadUint16(b)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
res := uint64(val)
|
|
||||||
|
|
||||||
if res < (1 << uFloat16MantissaEffectiveBits) {
|
|
||||||
// Fast path: either the value is denormalized (no hidden bit), or
|
|
||||||
// normalized (hidden bit set, exponent offset by one) with exponent zero.
|
|
||||||
// Zero exponent offset by one sets the bit exactly where the hidden bit is.
|
|
||||||
// So in both cases the value encodes itself.
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
exponent := val >> uFloat16MantissaBits // No sign extend on uint!
|
|
||||||
// After the fast pass, the exponent is at least one (offset by one).
|
|
||||||
// Un-offset the exponent.
|
|
||||||
exponent--
|
|
||||||
// Here we need to clear the exponent and set the hidden bit. We have already
|
|
||||||
// decremented the exponent, so when we subtract it, it leaves behind the
|
|
||||||
// hidden bit.
|
|
||||||
res -= uint64(exponent) << uFloat16MantissaBits
|
|
||||||
res <<= exponent
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeUfloat16 writes a float in the QUIC-float16 format from its uint64 representation
|
|
||||||
func writeUfloat16(b *bytes.Buffer, byteOrder ByteOrder, value uint64) {
|
|
||||||
var result uint16
|
|
||||||
if value < (uint64(1) << uFloat16MantissaEffectiveBits) {
|
|
||||||
// Fast path: either the value is denormalized, or has exponent zero.
|
|
||||||
// Both cases are represented by the value itself.
|
|
||||||
result = uint16(value)
|
|
||||||
} else if value >= uFloat16MaxValue {
|
|
||||||
// Value is out of range; clamp it to the maximum representable.
|
|
||||||
result = math.MaxUint16
|
|
||||||
} else {
|
|
||||||
// The highest bit is between position 13 and 42 (zero-based), which
|
|
||||||
// corresponds to exponent 1-30. In the output, mantissa is from 0 to 10,
|
|
||||||
// hidden bit is 11 and exponent is 11 to 15. Shift the highest bit to 11
|
|
||||||
// and count the shifts.
|
|
||||||
exponent := uint16(0)
|
|
||||||
for offset := uint16(16); offset > 0; offset /= 2 {
|
|
||||||
// Right-shift the value until the highest bit is in position 11.
|
|
||||||
// For offset of 16, 8, 4, 2 and 1 (binary search over 1-30),
|
|
||||||
// shift if the bit is at or above 11 + offset.
|
|
||||||
if value >= (uint64(1) << (uFloat16MantissaBits + offset)) {
|
|
||||||
exponent += offset
|
|
||||||
value >>= offset
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hidden bit (position 11) is set. We should remove it and increment the
|
|
||||||
// exponent. Equivalently, we just add it to the exponent.
|
|
||||||
// This hides the bit.
|
|
||||||
result = (uint16(value) + (exponent << uFloat16MantissaBits))
|
|
||||||
}
|
|
||||||
|
|
||||||
byteOrder.WriteUint16(b, result)
|
|
||||||
}
|
|
|
@ -13,29 +13,21 @@ import (
|
||||||
// TODO: use the value sent in the transport parameters
|
// TODO: use the value sent in the transport parameters
|
||||||
const ackDelayExponent = 3
|
const ackDelayExponent = 3
|
||||||
|
|
||||||
|
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
|
||||||
|
|
||||||
// An AckFrame is an ACK frame
|
// An AckFrame is an ACK frame
|
||||||
type AckFrame struct {
|
type AckFrame struct {
|
||||||
AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
|
AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
|
||||||
DelayTime time.Duration
|
DelayTime time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
|
|
||||||
return parseAckOrAckEcnFrame(r, false, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseAckEcnFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
|
|
||||||
return parseAckOrAckEcnFrame(r, true, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseAckFrame reads an ACK frame
|
// parseAckFrame reads an ACK frame
|
||||||
func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNumber) (*AckFrame, error) {
|
func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
|
||||||
if !version.UsesIETFFrameFormat() {
|
typeByte, err := r.ReadByte()
|
||||||
return parseAckFrameLegacy(r, version)
|
if err != nil {
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := r.ReadByte(); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
ecn := typeByte&0x1 > 0
|
||||||
|
|
||||||
frame := &AckFrame{}
|
frame := &AckFrame{}
|
||||||
|
|
||||||
|
@ -50,14 +42,6 @@ func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNu
|
||||||
}
|
}
|
||||||
frame.DelayTime = time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
|
frame.DelayTime = time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
|
||||||
|
|
||||||
if ecn {
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
if _, err := utils.ReadVarInt(r); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
numBlocks, err := utils.ReadVarInt(r)
|
numBlocks, err := utils.ReadVarInt(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -103,16 +87,22 @@ func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNu
|
||||||
if !frame.validateAckRanges() {
|
if !frame.validateAckRanges() {
|
||||||
return nil, errInvalidAckRanges
|
return nil, errInvalidAckRanges
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse (and skip) the ECN section
|
||||||
|
if ecn {
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if _, err := utils.ReadVarInt(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return frame, nil
|
return frame, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes an ACK frame.
|
// Write writes an ACK frame.
|
||||||
func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||||
if !version.UsesIETFFrameFormat() {
|
b.WriteByte(0x2)
|
||||||
return f.writeLegacy(b, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.WriteByte(0x0d)
|
|
||||||
utils.WriteVarInt(b, uint64(f.LargestAcked()))
|
utils.WriteVarInt(b, uint64(f.LargestAcked()))
|
||||||
utils.WriteVarInt(b, encodeAckDelay(f.DelayTime))
|
utils.WriteVarInt(b, encodeAckDelay(f.DelayTime))
|
||||||
|
|
||||||
|
@ -134,10 +124,6 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||||
if !version.UsesIETFFrameFormat() {
|
|
||||||
return f.lengthLegacy(version)
|
|
||||||
}
|
|
||||||
|
|
||||||
largestAcked := f.AckRanges[0].Largest
|
largestAcked := f.AckRanges[0].Largest
|
||||||
numRanges := f.numEncodableAckRanges()
|
numRanges := f.numEncodableAckRanges()
|
||||||
|
|
||||||
|
|
|
@ -1,364 +0,0 @@
|
||||||
package wire
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
|
|
||||||
|
|
||||||
func parseAckFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (*AckFrame, error) {
|
|
||||||
frame := &AckFrame{}
|
|
||||||
|
|
||||||
typeByte, err := r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hasMissingRanges := typeByte&0x20 == 0x20
|
|
||||||
largestAckedLen := 2 * ((typeByte & 0x0C) >> 2)
|
|
||||||
if largestAckedLen == 0 {
|
|
||||||
largestAckedLen = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
missingSequenceNumberDeltaLen := 2 * (typeByte & 0x03)
|
|
||||||
if missingSequenceNumberDeltaLen == 0 {
|
|
||||||
missingSequenceNumberDeltaLen = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
la, err := utils.BigEndian.ReadUintN(r, largestAckedLen)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
largestAcked := protocol.PacketNumber(la)
|
|
||||||
|
|
||||||
delay, err := utils.BigEndian.ReadUfloat16(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.DelayTime = time.Duration(delay) * time.Microsecond
|
|
||||||
|
|
||||||
var numAckBlocks uint8
|
|
||||||
if hasMissingRanges {
|
|
||||||
numAckBlocks, err = r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasMissingRanges && numAckBlocks == 0 {
|
|
||||||
return nil, errInvalidAckRanges
|
|
||||||
}
|
|
||||||
|
|
||||||
abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ackBlockLength := protocol.PacketNumber(abl)
|
|
||||||
if largestAcked > 0 && ackBlockLength < 1 {
|
|
||||||
return nil, errors.New("invalid first ACK range")
|
|
||||||
}
|
|
||||||
|
|
||||||
if ackBlockLength > largestAcked+1 {
|
|
||||||
return nil, errInvalidAckRanges
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasMissingRanges {
|
|
||||||
ackRange := AckRange{
|
|
||||||
Smallest: largestAcked - ackBlockLength + 1,
|
|
||||||
Largest: largestAcked,
|
|
||||||
}
|
|
||||||
frame.AckRanges = append(frame.AckRanges, ackRange)
|
|
||||||
|
|
||||||
var inLongBlock bool
|
|
||||||
var lastRangeComplete bool
|
|
||||||
for i := uint8(0); i < numAckBlocks; i++ {
|
|
||||||
var gap uint8
|
|
||||||
gap, err = r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
abl, err := utils.BigEndian.ReadUintN(r, missingSequenceNumberDeltaLen)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ackBlockLength := protocol.PacketNumber(abl)
|
|
||||||
|
|
||||||
if inLongBlock {
|
|
||||||
frame.AckRanges[len(frame.AckRanges)-1].Smallest -= protocol.PacketNumber(gap) + ackBlockLength
|
|
||||||
frame.AckRanges[len(frame.AckRanges)-1].Largest -= protocol.PacketNumber(gap)
|
|
||||||
} else {
|
|
||||||
lastRangeComplete = false
|
|
||||||
ackRange := AckRange{
|
|
||||||
Largest: frame.AckRanges[len(frame.AckRanges)-1].Smallest - protocol.PacketNumber(gap) - 1,
|
|
||||||
}
|
|
||||||
ackRange.Smallest = ackRange.Largest - ackBlockLength + 1
|
|
||||||
frame.AckRanges = append(frame.AckRanges, ackRange)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ackBlockLength > 0 {
|
|
||||||
lastRangeComplete = true
|
|
||||||
}
|
|
||||||
inLongBlock = (ackBlockLength == 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the last range was not complete, First and Last make no sense
|
|
||||||
// remove the range from frame.AckRanges
|
|
||||||
if !lastRangeComplete {
|
|
||||||
frame.AckRanges = frame.AckRanges[:len(frame.AckRanges)-1]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
frame.AckRanges = make([]AckRange, 1)
|
|
||||||
if largestAcked != 0 {
|
|
||||||
frame.AckRanges[0].Largest = largestAcked
|
|
||||||
frame.AckRanges[0].Smallest = largestAcked + 1 - ackBlockLength
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !frame.validateAckRanges() {
|
|
||||||
return nil, errInvalidAckRanges
|
|
||||||
}
|
|
||||||
|
|
||||||
var numTimestamp byte
|
|
||||||
numTimestamp, err = r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if numTimestamp > 0 {
|
|
||||||
// Delta Largest acked
|
|
||||||
_, err = r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// First Timestamp
|
|
||||||
_, err = utils.BigEndian.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < int(numTimestamp)-1; i++ {
|
|
||||||
// Delta Largest acked
|
|
||||||
_, err = r.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Time Since Previous Timestamp
|
|
||||||
_, err = utils.BigEndian.ReadUint16(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *AckFrame) writeLegacy(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
|
||||||
largestAcked := f.LargestAcked()
|
|
||||||
largestAckedLen := protocol.GetPacketNumberLength(largestAcked)
|
|
||||||
|
|
||||||
typeByte := uint8(0x40)
|
|
||||||
|
|
||||||
if largestAckedLen != protocol.PacketNumberLen1 {
|
|
||||||
typeByte ^= (uint8(largestAckedLen / 2)) << 2
|
|
||||||
}
|
|
||||||
|
|
||||||
missingSequenceNumberDeltaLen := f.getMissingSequenceNumberDeltaLen()
|
|
||||||
if missingSequenceNumberDeltaLen != protocol.PacketNumberLen1 {
|
|
||||||
typeByte ^= (uint8(missingSequenceNumberDeltaLen / 2))
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.HasMissingRanges() {
|
|
||||||
typeByte |= 0x20
|
|
||||||
}
|
|
||||||
|
|
||||||
b.WriteByte(typeByte)
|
|
||||||
|
|
||||||
switch largestAckedLen {
|
|
||||||
case protocol.PacketNumberLen1:
|
|
||||||
b.WriteByte(uint8(largestAcked))
|
|
||||||
case protocol.PacketNumberLen2:
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(largestAcked))
|
|
||||||
case protocol.PacketNumberLen4:
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(largestAcked))
|
|
||||||
case protocol.PacketNumberLen6:
|
|
||||||
utils.BigEndian.WriteUint48(b, uint64(largestAcked)&(1<<48-1))
|
|
||||||
}
|
|
||||||
|
|
||||||
utils.BigEndian.WriteUfloat16(b, uint64(f.DelayTime/time.Microsecond))
|
|
||||||
|
|
||||||
var numRanges uint64
|
|
||||||
var numRangesWritten uint64
|
|
||||||
if f.HasMissingRanges() {
|
|
||||||
numRanges = f.numWritableNackRanges()
|
|
||||||
if numRanges > 0xFF {
|
|
||||||
panic("AckFrame: Too many ACK ranges")
|
|
||||||
}
|
|
||||||
b.WriteByte(uint8(numRanges - 1))
|
|
||||||
}
|
|
||||||
|
|
||||||
var firstAckBlockLength protocol.PacketNumber
|
|
||||||
if !f.HasMissingRanges() {
|
|
||||||
firstAckBlockLength = largestAcked - f.LowestAcked() + 1
|
|
||||||
} else {
|
|
||||||
firstAckBlockLength = largestAcked - f.AckRanges[0].Smallest + 1
|
|
||||||
numRangesWritten++
|
|
||||||
}
|
|
||||||
|
|
||||||
switch missingSequenceNumberDeltaLen {
|
|
||||||
case protocol.PacketNumberLen1:
|
|
||||||
b.WriteByte(uint8(firstAckBlockLength))
|
|
||||||
case protocol.PacketNumberLen2:
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(firstAckBlockLength))
|
|
||||||
case protocol.PacketNumberLen4:
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(firstAckBlockLength))
|
|
||||||
case protocol.PacketNumberLen6:
|
|
||||||
utils.BigEndian.WriteUint48(b, uint64(firstAckBlockLength)&(1<<48-1))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, ackRange := range f.AckRanges {
|
|
||||||
if i == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
length := ackRange.Largest - ackRange.Smallest + 1
|
|
||||||
gap := f.AckRanges[i-1].Smallest - ackRange.Largest - 1
|
|
||||||
|
|
||||||
num := gap/0xFF + 1
|
|
||||||
if gap%0xFF == 0 {
|
|
||||||
num--
|
|
||||||
}
|
|
||||||
|
|
||||||
if num == 1 {
|
|
||||||
b.WriteByte(uint8(gap))
|
|
||||||
switch missingSequenceNumberDeltaLen {
|
|
||||||
case protocol.PacketNumberLen1:
|
|
||||||
b.WriteByte(uint8(length))
|
|
||||||
case protocol.PacketNumberLen2:
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(length))
|
|
||||||
case protocol.PacketNumberLen4:
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(length))
|
|
||||||
case protocol.PacketNumberLen6:
|
|
||||||
utils.BigEndian.WriteUint48(b, uint64(length)&(1<<48-1))
|
|
||||||
}
|
|
||||||
numRangesWritten++
|
|
||||||
} else {
|
|
||||||
for i := 0; i < int(num); i++ {
|
|
||||||
var lengthWritten uint64
|
|
||||||
var gapWritten uint8
|
|
||||||
|
|
||||||
if i == int(num)-1 { // last block
|
|
||||||
lengthWritten = uint64(length)
|
|
||||||
gapWritten = uint8(1 + ((gap - 1) % 255))
|
|
||||||
} else {
|
|
||||||
lengthWritten = 0
|
|
||||||
gapWritten = 0xFF
|
|
||||||
}
|
|
||||||
|
|
||||||
b.WriteByte(gapWritten)
|
|
||||||
switch missingSequenceNumberDeltaLen {
|
|
||||||
case protocol.PacketNumberLen1:
|
|
||||||
b.WriteByte(uint8(lengthWritten))
|
|
||||||
case protocol.PacketNumberLen2:
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(lengthWritten))
|
|
||||||
case protocol.PacketNumberLen4:
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(lengthWritten))
|
|
||||||
case protocol.PacketNumberLen6:
|
|
||||||
utils.BigEndian.WriteUint48(b, lengthWritten&(1<<48-1))
|
|
||||||
}
|
|
||||||
|
|
||||||
numRangesWritten++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// this is needed if not all AckRanges can be written to the ACK frame (if there are more than 0xFF)
|
|
||||||
if numRangesWritten >= numRanges {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if numRanges != numRangesWritten {
|
|
||||||
return errors.New("BUG: Inconsistent number of ACK ranges written")
|
|
||||||
}
|
|
||||||
|
|
||||||
b.WriteByte(0) // no timestamps
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *AckFrame) lengthLegacy(_ protocol.VersionNumber) protocol.ByteCount {
|
|
||||||
length := protocol.ByteCount(1 + 2 + 1) // 1 TypeByte, 2 ACK delay time, 1 Num Timestamp
|
|
||||||
length += protocol.ByteCount(protocol.GetPacketNumberLength(f.LargestAcked()))
|
|
||||||
|
|
||||||
missingSequenceNumberDeltaLen := protocol.ByteCount(f.getMissingSequenceNumberDeltaLen())
|
|
||||||
|
|
||||||
if f.HasMissingRanges() {
|
|
||||||
length += (1 + missingSequenceNumberDeltaLen) * protocol.ByteCount(f.numWritableNackRanges())
|
|
||||||
} else {
|
|
||||||
length += missingSequenceNumberDeltaLen
|
|
||||||
}
|
|
||||||
// we don't write
|
|
||||||
return length
|
|
||||||
}
|
|
||||||
|
|
||||||
// numWritableNackRanges calculates the number of ACK blocks that are about to be written
|
|
||||||
// this number is different from len(f.AckRanges) for the case of long gaps (> 255 packets)
|
|
||||||
func (f *AckFrame) numWritableNackRanges() uint64 {
|
|
||||||
if len(f.AckRanges) == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
var numRanges uint64
|
|
||||||
for i, ackRange := range f.AckRanges {
|
|
||||||
if i == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
lastAckRange := f.AckRanges[i-1]
|
|
||||||
gap := lastAckRange.Smallest - ackRange.Largest - 1
|
|
||||||
rangeLength := 1 + uint64(gap)/0xFF
|
|
||||||
if uint64(gap)%0xFF == 0 {
|
|
||||||
rangeLength--
|
|
||||||
}
|
|
||||||
|
|
||||||
if numRanges+rangeLength < 0xFF {
|
|
||||||
numRanges += rangeLength
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return numRanges + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *AckFrame) getMissingSequenceNumberDeltaLen() protocol.PacketNumberLen {
|
|
||||||
var maxRangeLength protocol.PacketNumber
|
|
||||||
|
|
||||||
if f.HasMissingRanges() {
|
|
||||||
for _, ackRange := range f.AckRanges {
|
|
||||||
rangeLength := ackRange.Largest - ackRange.Smallest + 1
|
|
||||||
if rangeLength > maxRangeLength {
|
|
||||||
maxRangeLength = rangeLength
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
maxRangeLength = f.LargestAcked() - f.LowestAcked() + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxRangeLength <= 0xFF {
|
|
||||||
return protocol.PacketNumberLen1
|
|
||||||
}
|
|
||||||
if maxRangeLength <= 0xFFFF {
|
|
||||||
return protocol.PacketNumberLen2
|
|
||||||
}
|
|
||||||
if maxRangeLength <= 0xFFFFFFFF {
|
|
||||||
return protocol.PacketNumberLen4
|
|
||||||
}
|
|
||||||
|
|
||||||
return protocol.PacketNumberLen6
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
package wire
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A BlockedFrame is a BLOCKED frame
|
|
||||||
type BlockedFrame struct {
|
|
||||||
Offset protocol.ByteCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseBlockedFrame parses a BLOCKED frame
|
|
||||||
func parseBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*BlockedFrame, error) {
|
|
||||||
if _, err := r.ReadByte(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
offset, err := utils.ReadVarInt(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &BlockedFrame{
|
|
||||||
Offset: protocol.ByteCount(offset),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *BlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
|
||||||
if !version.UsesIETFFrameFormat() {
|
|
||||||
return (&blockedFrameLegacy{}).Write(b, version)
|
|
||||||
}
|
|
||||||
typeByte := uint8(0x08)
|
|
||||||
b.WriteByte(typeByte)
|
|
||||||
utils.WriteVarInt(b, uint64(f.Offset))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Length of a written frame
|
|
||||||
func (f *BlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
|
||||||
if !version.UsesIETFFrameFormat() {
|
|
||||||
return 1 + 4
|
|
||||||
}
|
|
||||||
return 1 + utils.VarIntLen(uint64(f.Offset))
|
|
||||||
}
|
|
37
vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go
generated
vendored
37
vendor/github.com/lucas-clemente/quic-go/internal/wire/blocked_frame_legacy.go
generated
vendored
|
@ -1,37 +0,0 @@
|
||||||
package wire
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type blockedFrameLegacy struct {
|
|
||||||
StreamID protocol.StreamID
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseBlockedFrameLegacy parses a BLOCKED frame (in gQUIC format)
|
|
||||||
// The frame returned is
|
|
||||||
// * a STREAM_BLOCKED frame, if the BLOCKED applies to a stream
|
|
||||||
// * a BLOCKED frame, if the BLOCKED applies to the connection
|
|
||||||
func parseBlockedFrameLegacy(r *bytes.Reader, _ protocol.VersionNumber) (Frame, error) {
|
|
||||||
if _, err := r.ReadByte(); err != nil { // read the TypeByte
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
streamID, err := utils.BigEndian.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if streamID == 0 {
|
|
||||||
return &BlockedFrame{}, nil
|
|
||||||
}
|
|
||||||
return &StreamBlockedFrame{StreamID: protocol.StreamID(streamID)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//Write writes a BLOCKED frame
|
|
||||||
func (f *blockedFrameLegacy) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x05)
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(f.StreamID))
|
|
||||||
return nil
|
|
||||||
}
|
|
73
vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go
generated
vendored
73
vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go
generated
vendored
|
@ -2,52 +2,43 @@ package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// A ConnectionCloseFrame in QUIC
|
// A ConnectionCloseFrame is a CONNECTION_CLOSE frame
|
||||||
type ConnectionCloseFrame struct {
|
type ConnectionCloseFrame struct {
|
||||||
|
IsApplicationError bool
|
||||||
ErrorCode qerr.ErrorCode
|
ErrorCode qerr.ErrorCode
|
||||||
ReasonPhrase string
|
ReasonPhrase string
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseConnectionCloseFrame reads a CONNECTION_CLOSE frame
|
|
||||||
func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) {
|
func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber) (*ConnectionCloseFrame, error) {
|
||||||
if _, err := r.ReadByte(); err != nil { // read the TypeByte
|
typeByte, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var errorCode qerr.ErrorCode
|
f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d}
|
||||||
var reasonPhraseLen uint64
|
|
||||||
if version.UsesIETFFrameFormat() {
|
|
||||||
ec, err := utils.BigEndian.ReadUint16(r)
|
ec, err := utils.BigEndian.ReadUint16(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
errorCode = qerr.ErrorCode(ec)
|
f.ErrorCode = qerr.ErrorCode(ec)
|
||||||
|
// read the Frame Type, if this is not an application error
|
||||||
|
if !f.IsApplicationError {
|
||||||
|
if _, err := utils.ReadVarInt(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var reasonPhraseLen uint64
|
||||||
reasonPhraseLen, err = utils.ReadVarInt(r)
|
reasonPhraseLen, err = utils.ReadVarInt(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
ec, err := utils.BigEndian.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
errorCode = qerr.ErrorCode(ec)
|
|
||||||
length, err := utils.BigEndian.ReadUint16(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
reasonPhraseLen = uint64(length)
|
|
||||||
}
|
|
||||||
|
|
||||||
// shortcut to prevent the unnecessary allocation of dataLen bytes
|
// shortcut to prevent the unnecessary allocation of dataLen bytes
|
||||||
// if the dataLen is larger than the remaining length of the packet
|
// if the dataLen is larger than the remaining length of the packet
|
||||||
// reading the whole reason phrase would result in EOF when attempting to READ
|
// reading the whole reason phrase would result in EOF when attempting to READ
|
||||||
|
@ -60,37 +51,31 @@ func parseConnectionCloseFrame(r *bytes.Reader, version protocol.VersionNumber)
|
||||||
// this should never happen, since we already checked the reasonPhraseLen earlier
|
// this should never happen, since we already checked the reasonPhraseLen earlier
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
f.ReasonPhrase = string(reasonPhrase)
|
||||||
return &ConnectionCloseFrame{
|
return f, nil
|
||||||
ErrorCode: errorCode,
|
|
||||||
ReasonPhrase: string(reasonPhrase),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length of a written frame
|
// Length of a written frame
|
||||||
func (f *ConnectionCloseFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
func (f *ConnectionCloseFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||||
if version.UsesIETFFrameFormat() {
|
length := 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
|
||||||
return 1 + 2 + utils.VarIntLen(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
|
if !f.IsApplicationError {
|
||||||
|
length++ // for the frame type
|
||||||
}
|
}
|
||||||
return 1 + 4 + 2 + protocol.ByteCount(len(f.ReasonPhrase))
|
return length
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes an CONNECTION_CLOSE frame.
|
|
||||||
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||||
b.WriteByte(0x02)
|
if f.IsApplicationError {
|
||||||
|
b.WriteByte(0x1d)
|
||||||
if len(f.ReasonPhrase) > math.MaxUint16 {
|
|
||||||
return errors.New("ConnectionFrame: ReasonPhrase too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
if version.UsesIETFFrameFormat() {
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
|
|
||||||
utils.WriteVarInt(b, uint64(len(f.ReasonPhrase)))
|
|
||||||
} else {
|
} else {
|
||||||
utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
|
b.WriteByte(0x1c)
|
||||||
utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase)))
|
|
||||||
}
|
}
|
||||||
b.WriteString(f.ReasonPhrase)
|
|
||||||
|
|
||||||
|
utils.BigEndian.WriteUint16(b, uint16(f.ErrorCode))
|
||||||
|
if !f.IsApplicationError {
|
||||||
|
utils.WriteVarInt(b, 0)
|
||||||
|
}
|
||||||
|
utils.WriteVarInt(b, uint64(len(f.ReasonPhrase)))
|
||||||
|
b.WriteString(f.ReasonPhrase)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
71
vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go
generated
vendored
Normal file
71
vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go
generated
vendored
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package wire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A CryptoFrame is a CRYPTO frame
|
||||||
|
type CryptoFrame struct {
|
||||||
|
Offset protocol.ByteCount
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) {
|
||||||
|
if _, err := r.ReadByte(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
frame := &CryptoFrame{}
|
||||||
|
offset, err := utils.ReadVarInt(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
frame.Offset = protocol.ByteCount(offset)
|
||||||
|
dataLen, err := utils.ReadVarInt(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if dataLen > uint64(r.Len()) {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
if dataLen != 0 {
|
||||||
|
frame.Data = make([]byte, dataLen)
|
||||||
|
if _, err := io.ReadFull(r, frame.Data); err != nil {
|
||||||
|
// this should never happen, since we already checked the dataLen earlier
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return frame, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
||||||
|
b.WriteByte(0x6)
|
||||||
|
utils.WriteVarInt(b, uint64(f.Offset))
|
||||||
|
utils.WriteVarInt(b, uint64(len(f.Data)))
|
||||||
|
b.Write(f.Data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Length of a written frame
|
||||||
|
func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||||
|
return 1 + utils.VarIntLen(uint64(f.Offset)) + utils.VarIntLen(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxDataLen returns the maximum data length
|
||||||
|
func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount {
|
||||||
|
// pretend that the data size will be 1 bytes
|
||||||
|
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
|
||||||
|
headerLen := 1 + utils.VarIntLen(uint64(f.Offset)) + 1
|
||||||
|
if headerLen > maxSize {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
maxDataLen := maxSize - headerLen
|
||||||
|
if utils.VarIntLen(uint64(maxDataLen)) != 1 {
|
||||||
|
maxDataLen--
|
||||||
|
}
|
||||||
|
return maxDataLen
|
||||||
|
}
|
38
vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go
generated
vendored
Normal file
38
vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go
generated
vendored
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
package wire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A DataBlockedFrame is a DATA_BLOCKED frame
|
||||||
|
type DataBlockedFrame struct {
|
||||||
|
DataLimit protocol.ByteCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) {
|
||||||
|
if _, err := r.ReadByte(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
offset, err := utils.ReadVarInt(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &DataBlockedFrame{
|
||||||
|
DataLimit: protocol.ByteCount(offset),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
|
||||||
|
typeByte := uint8(0x14)
|
||||||
|
b.WriteByte(typeByte)
|
||||||
|
utils.WriteVarInt(b, uint64(f.DataLimit))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Length of a written frame
|
||||||
|
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
||||||
|
return 1 + utils.VarIntLen(uint64(f.DataLimit))
|
||||||
|
}
|
|
@ -2,16 +2,15 @@ package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseNextFrame parses the next frame
|
// ParseNextFrame parses the next frame
|
||||||
// It skips PADDING frames.
|
// It skips PADDING frames.
|
||||||
func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Frame, error) {
|
func ParseNextFrame(r *bytes.Reader, v protocol.VersionNumber) (Frame, error) {
|
||||||
for r.Len() != 0 {
|
for r.Len() != 0 {
|
||||||
typeByte, _ := r.ReadByte()
|
typeByte, _ := r.ReadByte()
|
||||||
if typeByte == 0x0 { // PADDING frame
|
if typeByte == 0x0 { // PADDING frame
|
||||||
|
@ -19,154 +18,61 @@ func ParseNextFrame(r *bytes.Reader, hdr *Header, v protocol.VersionNumber) (Fra
|
||||||
}
|
}
|
||||||
r.UnreadByte()
|
r.UnreadByte()
|
||||||
|
|
||||||
if !v.UsesIETFFrameFormat() {
|
return parseFrame(r, typeByte, v)
|
||||||
return parseGQUICFrame(r, typeByte, hdr, v)
|
|
||||||
}
|
|
||||||
return parseIETFFrame(r, typeByte, v)
|
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIETFFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) {
|
func parseFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (Frame, error) {
|
||||||
var frame Frame
|
var frame Frame
|
||||||
var err error
|
var err error
|
||||||
if typeByte&0xf8 == 0x10 {
|
if typeByte&0xf8 == 0x8 {
|
||||||
frame, err = parseStreamFrame(r, v)
|
frame, err = parseStreamFrame(r, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = qerr.Error(qerr.InvalidStreamData, err.Error())
|
return nil, qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||||
}
|
}
|
||||||
return frame, err
|
return frame, nil
|
||||||
}
|
}
|
||||||
// TODO: implement all IETF QUIC frame types
|
|
||||||
switch typeByte {
|
switch typeByte {
|
||||||
case 0x1:
|
case 0x1:
|
||||||
frame, err = parseRstStreamFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x2:
|
|
||||||
frame, err = parseConnectionCloseFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x4:
|
|
||||||
frame, err = parseMaxDataFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x5:
|
|
||||||
frame, err = parseMaxStreamDataFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x6:
|
|
||||||
frame, err = parseMaxStreamIDFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x7:
|
|
||||||
frame, err = parsePingFrame(r, v)
|
frame, err = parsePingFrame(r, v)
|
||||||
case 0x8:
|
case 0x2, 0x3:
|
||||||
frame, err = parseBlockedFrame(r, v)
|
frame, err = parseAckFrame(r, v)
|
||||||
if err != nil {
|
case 0x4:
|
||||||
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
|
frame, err = parseResetStreamFrame(r, v)
|
||||||
}
|
case 0x5:
|
||||||
case 0x9:
|
|
||||||
frame, err = parseStreamBlockedFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
|
|
||||||
}
|
|
||||||
case 0xa:
|
|
||||||
frame, err = parseStreamIDBlockedFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
|
||||||
}
|
|
||||||
case 0xc:
|
|
||||||
frame, err = parseStopSendingFrame(r, v)
|
frame, err = parseStopSendingFrame(r, v)
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
|
||||||
}
|
|
||||||
case 0xd:
|
|
||||||
frame, err = parseAckFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidAckData, err.Error())
|
|
||||||
}
|
|
||||||
case 0xe:
|
|
||||||
frame, err = parsePathChallengeFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
|
||||||
}
|
|
||||||
case 0xf:
|
|
||||||
frame, err = parsePathResponseFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidFrameData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x1a:
|
|
||||||
frame, err = parseAckEcnFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidAckData, err.Error())
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
|
|
||||||
}
|
|
||||||
return frame, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseGQUICFrame(r *bytes.Reader, typeByte byte, hdr *Header, v protocol.VersionNumber) (Frame, error) {
|
|
||||||
var frame Frame
|
|
||||||
var err error
|
|
||||||
if typeByte&0x80 == 0x80 {
|
|
||||||
frame, err = parseStreamFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidStreamData, err.Error())
|
|
||||||
}
|
|
||||||
return frame, err
|
|
||||||
} else if typeByte&0xc0 == 0x40 {
|
|
||||||
frame, err = parseAckFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidAckData, err.Error())
|
|
||||||
}
|
|
||||||
return frame, err
|
|
||||||
}
|
|
||||||
switch typeByte {
|
|
||||||
case 0x1:
|
|
||||||
frame, err = parseRstStreamFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidRstStreamData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x2:
|
|
||||||
frame, err = parseConnectionCloseFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidConnectionCloseData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x3:
|
|
||||||
frame, err = parseGoawayFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidGoawayData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x4:
|
|
||||||
frame, err = parseWindowUpdateFrame(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidWindowUpdateData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x5:
|
|
||||||
frame, err = parseBlockedFrameLegacy(r, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidBlockedData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x6:
|
case 0x6:
|
||||||
if !v.UsesStopWaitingFrames() {
|
frame, err = parseCryptoFrame(r, v)
|
||||||
err = errors.New("STOP_WAITING frames not supported by this QUIC version")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
frame, err = parseStopWaitingFrame(r, hdr.PacketNumber, hdr.PacketNumberLen, v)
|
|
||||||
if err != nil {
|
|
||||||
err = qerr.Error(qerr.InvalidStopWaitingData, err.Error())
|
|
||||||
}
|
|
||||||
case 0x7:
|
case 0x7:
|
||||||
frame, err = parsePingFrame(r, v)
|
frame, err = parseNewTokenFrame(r, v)
|
||||||
|
case 0x10:
|
||||||
|
frame, err = parseMaxDataFrame(r, v)
|
||||||
|
case 0x11:
|
||||||
|
frame, err = parseMaxStreamDataFrame(r, v)
|
||||||
|
case 0x12, 0x13:
|
||||||
|
frame, err = parseMaxStreamsFrame(r, v)
|
||||||
|
case 0x14:
|
||||||
|
frame, err = parseDataBlockedFrame(r, v)
|
||||||
|
case 0x15:
|
||||||
|
frame, err = parseStreamDataBlockedFrame(r, v)
|
||||||
|
case 0x16, 0x17:
|
||||||
|
frame, err = parseStreamsBlockedFrame(r, v)
|
||||||
|
case 0x18:
|
||||||
|
frame, err = parseNewConnectionIDFrame(r, v)
|
||||||
|
case 0x19:
|
||||||
|
frame, err = parseRetireConnectionIDFrame(r, v)
|
||||||
|
case 0x1a:
|
||||||
|
frame, err = parsePathChallengeFrame(r, v)
|
||||||
|
case 0x1b:
|
||||||
|
frame, err = parsePathResponseFrame(r, v)
|
||||||
|
case 0x1c, 0x1d:
|
||||||
|
frame, err = parseConnectionCloseFrame(r, v)
|
||||||
default:
|
default:
|
||||||
err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
|
err = fmt.Errorf("unknown type byte 0x%x", typeByte)
|
||||||
}
|
}
|
||||||
return frame, err
|
if err != nil {
|
||||||
|
return nil, qerr.Error(qerr.InvalidFrameData, err.Error())
|
||||||
|
}
|
||||||
|
return frame, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,68 +0,0 @@
|
||||||
package wire
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A GoawayFrame is a GOAWAY frame
|
|
||||||
type GoawayFrame struct {
|
|
||||||
ErrorCode qerr.ErrorCode
|
|
||||||
LastGoodStream protocol.StreamID
|
|
||||||
ReasonPhrase string
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseGoawayFrame parses a GOAWAY frame
|
|
||||||
func parseGoawayFrame(r *bytes.Reader, _ protocol.VersionNumber) (*GoawayFrame, error) {
|
|
||||||
frame := &GoawayFrame{}
|
|
||||||
|
|
||||||
if _, err := r.ReadByte(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
errorCode, err := utils.BigEndian.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ErrorCode = qerr.ErrorCode(errorCode)
|
|
||||||
|
|
||||||
lastGoodStream, err := utils.BigEndian.ReadUint32(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.LastGoodStream = protocol.StreamID(lastGoodStream)
|
|
||||||
|
|
||||||
reasonPhraseLen, err := utils.BigEndian.ReadUint16(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if reasonPhraseLen > uint16(protocol.MaxReceivePacketSize) {
|
|
||||||
return nil, qerr.Error(qerr.InvalidGoawayData, "reason phrase too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
reasonPhrase := make([]byte, reasonPhraseLen)
|
|
||||||
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
frame.ReasonPhrase = string(reasonPhrase)
|
|
||||||
return frame, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *GoawayFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
|
|
||||||
b.WriteByte(0x03)
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(f.ErrorCode))
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(f.LastGoodStream))
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(len(f.ReasonPhrase)))
|
|
||||||
b.WriteString(f.ReasonPhrase)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Length of a written frame
|
|
||||||
func (f *GoawayFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
|
|
||||||
return protocol.ByteCount(1 + 4 + 4 + 2 + len(f.ReasonPhrase))
|
|
||||||
}
|
|
|
@ -3,7 +3,6 @@ package wire
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
@ -11,10 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// Header is the header of a QUIC packet.
|
// Header is the header of a QUIC packet.
|
||||||
// It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header.
|
|
||||||
type Header struct {
|
type Header struct {
|
||||||
IsPublicHeader bool
|
|
||||||
|
|
||||||
Raw []byte
|
Raw []byte
|
||||||
|
|
||||||
Version protocol.VersionNumber
|
Version protocol.VersionNumber
|
||||||
|
@ -29,12 +25,6 @@ type Header struct {
|
||||||
IsVersionNegotiation bool
|
IsVersionNegotiation bool
|
||||||
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
|
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
|
||||||
|
|
||||||
// only needed for the gQUIC Public Header
|
|
||||||
VersionFlag bool
|
|
||||||
ResetFlag bool
|
|
||||||
DiversificationNonce []byte
|
|
||||||
|
|
||||||
// only needed for the IETF Header
|
|
||||||
Type protocol.PacketType
|
Type protocol.PacketType
|
||||||
IsLongHeader bool
|
IsLongHeader bool
|
||||||
KeyPhase int
|
KeyPhase int
|
||||||
|
@ -42,15 +32,8 @@ type Header struct {
|
||||||
Token []byte
|
Token []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
var errInvalidPacketNumberLen = errors.New("invalid packet number length")
|
|
||||||
|
|
||||||
// Write writes the Header.
|
// Write writes the Header.
|
||||||
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
|
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
|
||||||
if !ver.UsesIETFHeaderFormat() {
|
|
||||||
h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
|
|
||||||
return h.writePublicHeader(b, pers, ver)
|
|
||||||
}
|
|
||||||
// write an IETF QUIC header
|
|
||||||
if h.IsLongHeader {
|
if h.IsLongHeader {
|
||||||
return h.writeLongHeader(b, ver)
|
return h.writeLongHeader(b, ver)
|
||||||
}
|
}
|
||||||
|
@ -69,7 +52,7 @@ func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) erro
|
||||||
b.Write(h.DestConnectionID.Bytes())
|
b.Write(h.DestConnectionID.Bytes())
|
||||||
b.Write(h.SrcConnectionID.Bytes())
|
b.Write(h.SrcConnectionID.Bytes())
|
||||||
|
|
||||||
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
|
if h.Type == protocol.PacketTypeInitial {
|
||||||
utils.WriteVarInt(b, uint64(len(h.Token)))
|
utils.WriteVarInt(b, uint64(len(h.Token)))
|
||||||
b.Write(h.Token)
|
b.Write(h.Token)
|
||||||
}
|
}
|
||||||
|
@ -89,176 +72,36 @@ func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) erro
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.UsesLengthInHeader() {
|
|
||||||
utils.WriteVarInt(b, uint64(h.PayloadLen))
|
utils.WriteVarInt(b, uint64(h.PayloadLen))
|
||||||
}
|
|
||||||
if v.UsesVarintPacketNumbers() {
|
|
||||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||||
}
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
|
||||||
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
|
|
||||||
if len(h.DiversificationNonce) != 32 {
|
|
||||||
return errors.New("invalid diversification nonce length")
|
|
||||||
}
|
|
||||||
b.Write(h.DiversificationNonce)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
||||||
typeByte := byte(0x30)
|
typeByte := byte(0x30)
|
||||||
typeByte |= byte(h.KeyPhase << 6)
|
typeByte |= byte(h.KeyPhase << 6)
|
||||||
if !v.UsesVarintPacketNumbers() {
|
|
||||||
switch h.PacketNumberLen {
|
|
||||||
case protocol.PacketNumberLen1:
|
|
||||||
case protocol.PacketNumberLen2:
|
|
||||||
typeByte |= 0x1
|
|
||||||
case protocol.PacketNumberLen4:
|
|
||||||
typeByte |= 0x2
|
|
||||||
default:
|
|
||||||
return errInvalidPacketNumberLen
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
b.WriteByte(typeByte)
|
b.WriteByte(typeByte)
|
||||||
b.Write(h.DestConnectionID.Bytes())
|
b.Write(h.DestConnectionID.Bytes())
|
||||||
|
|
||||||
if !v.UsesVarintPacketNumbers() {
|
|
||||||
switch h.PacketNumberLen {
|
|
||||||
case protocol.PacketNumberLen1:
|
|
||||||
b.WriteByte(uint8(h.PacketNumber))
|
|
||||||
case protocol.PacketNumberLen2:
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
|
|
||||||
case protocol.PacketNumberLen4:
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writePublicHeader writes a Public Header.
|
|
||||||
func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
|
|
||||||
if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
|
|
||||||
return errors.New("PublicHeader: Can only write regular packets")
|
|
||||||
}
|
|
||||||
if h.SrcConnectionID.Len() != 0 {
|
|
||||||
return errors.New("PublicHeader: SrcConnectionID must not be set")
|
|
||||||
}
|
|
||||||
if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
|
|
||||||
return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
|
|
||||||
}
|
|
||||||
|
|
||||||
publicFlagByte := uint8(0x00)
|
|
||||||
if h.VersionFlag {
|
|
||||||
publicFlagByte |= 0x01
|
|
||||||
}
|
|
||||||
if h.DestConnectionID.Len() > 0 {
|
|
||||||
publicFlagByte |= 0x08
|
|
||||||
}
|
|
||||||
if len(h.DiversificationNonce) > 0 {
|
|
||||||
if len(h.DiversificationNonce) != 32 {
|
|
||||||
return errors.New("invalid diversification nonce length")
|
|
||||||
}
|
|
||||||
publicFlagByte |= 0x04
|
|
||||||
}
|
|
||||||
switch h.PacketNumberLen {
|
|
||||||
case protocol.PacketNumberLen1:
|
|
||||||
publicFlagByte |= 0x00
|
|
||||||
case protocol.PacketNumberLen2:
|
|
||||||
publicFlagByte |= 0x10
|
|
||||||
case protocol.PacketNumberLen4:
|
|
||||||
publicFlagByte |= 0x20
|
|
||||||
}
|
|
||||||
b.WriteByte(publicFlagByte)
|
|
||||||
|
|
||||||
if h.DestConnectionID.Len() > 0 {
|
|
||||||
b.Write(h.DestConnectionID)
|
|
||||||
}
|
|
||||||
if h.VersionFlag && pers == protocol.PerspectiveClient {
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
|
||||||
}
|
|
||||||
if len(h.DiversificationNonce) > 0 {
|
|
||||||
b.Write(h.DiversificationNonce)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch h.PacketNumberLen {
|
|
||||||
case protocol.PacketNumberLen1:
|
|
||||||
b.WriteByte(uint8(h.PacketNumber))
|
|
||||||
case protocol.PacketNumberLen2:
|
|
||||||
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
|
|
||||||
case protocol.PacketNumberLen4:
|
|
||||||
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
|
|
||||||
case protocol.PacketNumberLen6:
|
|
||||||
return errInvalidPacketNumberLen
|
|
||||||
default:
|
|
||||||
return errors.New("PublicHeader: PacketNumberLen not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLength determines the length of the Header.
|
// GetLength determines the length of the Header.
|
||||||
func (h *Header) GetLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
|
func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
|
||||||
if !v.UsesIETFHeaderFormat() {
|
|
||||||
return h.getPublicHeaderLength()
|
|
||||||
}
|
|
||||||
return h.getHeaderLength(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Header) getHeaderLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
|
|
||||||
if h.IsLongHeader {
|
if h.IsLongHeader {
|
||||||
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
|
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.PayloadLen))
|
||||||
if v.UsesLengthInHeader() {
|
if h.Type == protocol.PacketTypeInitial {
|
||||||
length += utils.VarIntLen(uint64(h.PayloadLen))
|
|
||||||
}
|
|
||||||
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
|
|
||||||
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
||||||
}
|
}
|
||||||
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
|
return length
|
||||||
length += protocol.ByteCount(len(h.DiversificationNonce))
|
|
||||||
}
|
|
||||||
return length, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
||||||
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
|
|
||||||
return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
|
|
||||||
}
|
|
||||||
length += protocol.ByteCount(h.PacketNumberLen)
|
length += protocol.ByteCount(h.PacketNumberLen)
|
||||||
return length, nil
|
return length
|
||||||
}
|
|
||||||
|
|
||||||
// getPublicHeaderLength gets the length of the publicHeader in bytes.
|
|
||||||
// It can only be called for regular packets.
|
|
||||||
func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
|
|
||||||
length := protocol.ByteCount(1) // 1 byte for public flags
|
|
||||||
if h.PacketNumberLen == protocol.PacketNumberLen6 {
|
|
||||||
return 0, errInvalidPacketNumberLen
|
|
||||||
}
|
|
||||||
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
|
|
||||||
return 0, errPacketNumberLenNotSet
|
|
||||||
}
|
|
||||||
length += protocol.ByteCount(h.PacketNumberLen)
|
|
||||||
length += protocol.ByteCount(h.DestConnectionID.Len())
|
|
||||||
// Version Number in packets sent by the client
|
|
||||||
if h.VersionFlag {
|
|
||||||
length += 4
|
|
||||||
}
|
|
||||||
length += protocol.ByteCount(len(h.DiversificationNonce))
|
|
||||||
return length, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log logs the Header
|
// Log logs the Header
|
||||||
func (h *Header) Log(logger utils.Logger) {
|
func (h *Header) Log(logger utils.Logger) {
|
||||||
if h.IsPublicHeader {
|
|
||||||
h.logPublicHeader(logger)
|
|
||||||
} else {
|
|
||||||
h.logHeader(logger)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Header) logHeader(logger utils.Logger) {
|
|
||||||
if h.IsLongHeader {
|
if h.IsLongHeader {
|
||||||
if h.Version == 0 {
|
if h.Version == 0 {
|
||||||
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
|
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
|
||||||
|
@ -275,14 +118,6 @@ func (h *Header) logHeader(logger utils.Logger) {
|
||||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
|
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.Version == protocol.Version44 {
|
|
||||||
var divNonce string
|
|
||||||
if h.Type == protocol.PacketType0RTT {
|
|
||||||
divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce)
|
|
||||||
}
|
|
||||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
|
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -290,14 +125,6 @@ func (h *Header) logHeader(logger utils.Logger) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Header) logPublicHeader(logger utils.Logger) {
|
|
||||||
ver := "(unset)"
|
|
||||||
if h.Version != 0 {
|
|
||||||
ver = h.Version.String()
|
|
||||||
}
|
|
||||||
logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
||||||
dcil, err := encodeSingleConnIDLen(dest)
|
dcil, err := encodeSingleConnIDLen(dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -6,8 +6,8 @@ import (
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// The InvariantHeader is the version independent part of the header
|
// The InvariantHeader is the version independent part of the header
|
||||||
|
@ -32,23 +32,11 @@ func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Invariant
|
||||||
|
|
||||||
// If this is not a Long Header, it could either be a Public Header or a Short Header.
|
// If this is not a Long Header, it could either be a Public Header or a Short Header.
|
||||||
if !h.IsLongHeader {
|
if !h.IsLongHeader {
|
||||||
// In the Public Header 0x8 is the Connection ID Flag.
|
var err error
|
||||||
// In the IETF Short Header:
|
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
|
||||||
// * 0x8 it is the gQUIC Demultiplexing bit, and always 0.
|
|
||||||
// * 0x20 and 0x10 are always 1.
|
|
||||||
var connIDLen int
|
|
||||||
if typeByte&0x8 > 0 { // Public Header containing a connection ID
|
|
||||||
connIDLen = 8
|
|
||||||
}
|
|
||||||
if typeByte&0x38 == 0x30 { // Short Header
|
|
||||||
connIDLen = shortHeaderConnIDLen
|
|
||||||
}
|
|
||||||
if connIDLen > 0 {
|
|
||||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, connIDLen)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
// Long Header
|
// Long Header
|
||||||
|
@ -81,15 +69,6 @@ func (iv *InvariantHeader) Parse(b *bytes.Reader, sentBy protocol.Perspective, v
|
||||||
}
|
}
|
||||||
return iv.parseLongHeader(b, sentBy, ver)
|
return iv.parseLongHeader(b, sentBy, ver)
|
||||||
}
|
}
|
||||||
// The Public Header never uses 6 byte packet numbers.
|
|
||||||
// Therefore, the third and fourth bit will never be 11.
|
|
||||||
// For the Short Header, the third and fourth bit are always 11.
|
|
||||||
if iv.typeByte&0x30 != 0x30 {
|
|
||||||
if sentBy == protocol.PerspectiveServer && iv.typeByte&0x1 > 0 {
|
|
||||||
return iv.parseVersionNegotiationPacket(b)
|
|
||||||
}
|
|
||||||
return iv.parsePublicHeader(b, sentBy, ver)
|
|
||||||
}
|
|
||||||
return iv.parseShortHeader(b, ver)
|
return iv.parseShortHeader(b, ver)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,7 +83,6 @@ func (iv *InvariantHeader) toHeader() *Header {
|
||||||
|
|
||||||
func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Header, error) {
|
func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Header, error) {
|
||||||
h := iv.toHeader()
|
h := iv.toHeader()
|
||||||
h.VersionFlag = true
|
|
||||||
if b.Len() == 0 {
|
if b.Len() == 0 {
|
||||||
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
||||||
}
|
}
|
||||||
|
@ -145,7 +123,7 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Pers
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
|
if h.Type == protocol.PacketTypeInitial {
|
||||||
tokenLen, err := utils.ReadVarInt(b)
|
tokenLen, err := utils.ReadVarInt(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -159,37 +137,17 @@ func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Pers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.UsesLengthInHeader() {
|
|
||||||
pl, err := utils.ReadVarInt(b)
|
pl, err := utils.ReadVarInt(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
h.PayloadLen = protocol.ByteCount(pl)
|
h.PayloadLen = protocol.ByteCount(pl)
|
||||||
}
|
|
||||||
if v.UsesVarintPacketNumbers() {
|
|
||||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
h.PacketNumber = pn
|
h.PacketNumber = pn
|
||||||
h.PacketNumberLen = pnLen
|
h.PacketNumberLen = pnLen
|
||||||
} else {
|
|
||||||
pn, err := utils.BigEndian.ReadUint32(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.PacketNumber = protocol.PacketNumber(pn)
|
|
||||||
h.PacketNumberLen = protocol.PacketNumberLen4
|
|
||||||
}
|
|
||||||
if h.Type == protocol.PacketType0RTT && v == protocol.Version44 && sentBy == protocol.PerspectiveServer {
|
|
||||||
h.DiversificationNonce = make([]byte, 32)
|
|
||||||
if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil {
|
|
||||||
if err == io.ErrUnexpectedEOF {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
@ -198,76 +156,12 @@ func (iv *InvariantHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionN
|
||||||
h := iv.toHeader()
|
h := iv.toHeader()
|
||||||
h.KeyPhase = int(iv.typeByte&0x40) >> 6
|
h.KeyPhase = int(iv.typeByte&0x40) >> 6
|
||||||
|
|
||||||
if v.UsesVarintPacketNumbers() {
|
|
||||||
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
h.PacketNumber = pn
|
h.PacketNumber = pn
|
||||||
h.PacketNumberLen = pnLen
|
h.PacketNumberLen = pnLen
|
||||||
} else {
|
|
||||||
switch iv.typeByte & 0x3 {
|
|
||||||
case 0x0:
|
|
||||||
h.PacketNumberLen = protocol.PacketNumberLen1
|
|
||||||
case 0x1:
|
|
||||||
h.PacketNumberLen = protocol.PacketNumberLen2
|
|
||||||
case 0x2:
|
|
||||||
h.PacketNumberLen = protocol.PacketNumberLen4
|
|
||||||
default:
|
|
||||||
return nil, errInvalidPacketNumberLen
|
|
||||||
}
|
|
||||||
p, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.PacketNumber = protocol.PacketNumber(p)
|
|
||||||
}
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iv *InvariantHeader) parsePublicHeader(b *bytes.Reader, sentBy protocol.Perspective, ver protocol.VersionNumber) (*Header, error) {
|
|
||||||
h := iv.toHeader()
|
|
||||||
h.IsPublicHeader = true
|
|
||||||
h.ResetFlag = iv.typeByte&0x2 > 0
|
|
||||||
if h.ResetFlag {
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
h.VersionFlag = iv.typeByte&0x1 > 0
|
|
||||||
if h.VersionFlag && sentBy == protocol.PerspectiveClient {
|
|
||||||
v, err := utils.BigEndian.ReadUint32(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.Version = protocol.VersionNumber(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server.
|
|
||||||
// It doesn't have any meaning when sent by the client.
|
|
||||||
if sentBy == protocol.PerspectiveServer && iv.typeByte&0x4 > 0 {
|
|
||||||
h.DiversificationNonce = make([]byte, 32)
|
|
||||||
if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil {
|
|
||||||
if err == io.ErrUnexpectedEOF {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch iv.typeByte & 0x30 {
|
|
||||||
case 0x00:
|
|
||||||
h.PacketNumberLen = protocol.PacketNumberLen1
|
|
||||||
case 0x10:
|
|
||||||
h.PacketNumberLen = protocol.PacketNumberLen2
|
|
||||||
case 0x20:
|
|
||||||
h.PacketNumberLen = protocol.PacketNumberLen4
|
|
||||||
}
|
|
||||||
|
|
||||||
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.PacketNumber = protocol.PacketNumber(pn)
|
|
||||||
|
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,14 +18,11 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) {
|
||||||
dir = "->"
|
dir = "->"
|
||||||
}
|
}
|
||||||
switch f := frame.(type) {
|
switch f := frame.(type) {
|
||||||
|
case *CryptoFrame:
|
||||||
|
dataLen := protocol.ByteCount(len(f.Data))
|
||||||
|
logger.Debugf("\t%s &wire.CryptoFrame{Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.Offset, dataLen, f.Offset+dataLen)
|
||||||
case *StreamFrame:
|
case *StreamFrame:
|
||||||
logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
|
logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
|
||||||
case *StopWaitingFrame:
|
|
||||||
if sent {
|
|
||||||
logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
|
|
||||||
} else {
|
|
||||||
logger.Debugf("\t%s &wire.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
|
|
||||||
}
|
|
||||||
case *AckFrame:
|
case *AckFrame:
|
||||||
if len(f.AckRanges) > 1 {
|
if len(f.AckRanges) > 1 {
|
||||||
ackRanges := make([]string, len(f.AckRanges))
|
ackRanges := make([]string, len(f.AckRanges))
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue