vendor quic 44

pull/1435/head
Darien Raymond 2018-11-20 23:51:25 +01:00
parent bb8cab9cc7
commit 84f8bca01c
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
414 changed files with 78680 additions and 0 deletions

View File

@ -0,0 +1,44 @@
# Changelog
## v0.10.0 (2018-08-28)
- Add support for QUIC 44, drop support for QUIC 42.
## v0.9.0 (2018-08-15)
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
- Split Session.Close into one method for regular closing and one for closing with an error.
## v0.8.0 (2018-06-26)
- Add support for unidirectional streams (for IETF QUIC).
- Add a `quic.Config` option for the maximum number of incoming streams.
- Add support for QUIC 42 and 43.
- Add dial functions that use a context.
- Multiplex clients on a net.PacketConn, when using Dial(conn).
## v0.7.0 (2018-02-03)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
- Expose the `ConnectionState` in the `Session` (experimental API).
- Implement packet pacing.
## v0.6.0 (2017-12-12)
- Add support for QUIC 39, drop support for QUIC 35 - 37
- Added `quic.Config` options for maximal flow control windows
- Add a `quic.Config` option for QUIC versions
- Add a `quic.Config` option to request omission of the connection ID from a server
- Add a `quic.Config` option to configure the source address validation
- Add a `quic.Config` option to configure the handshake timeout
- Add a `quic.Config` option to configure the idle timeout
- Add a `quic.Config` option to configure keep-alive
- Rename the STK to Cookie
- Implement `net.Conn`-style deadlines for streams
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details.
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details.
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn`
- Drop support for Go 1.7 and 1.8.
- Various bugfixes

21
vendor/lucas-clemente/quic-go/LICENSE vendored Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2016 the quic-go authors & Google, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

73
vendor/lucas-clemente/quic-go/README.md vendored Normal file
View File

@ -0,0 +1,73 @@
# A QUIC implementation in pure Go
<img src="docs/quic.png" width=303 height=124>
[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](https://godoc.org/github.com/lucas-clemente/quic-go)
[![Travis Build Status](https://img.shields.io/travis/lucas-clemente/quic-go/master.svg?style=flat-square&label=Travis+build)](https://travis-ci.org/lucas-clemente/quic-go)
[![CircleCI Build Status](https://img.shields.io/circleci/project/github/lucas-clemente/quic-go.svg?style=flat-square&label=CircleCI+build)](https://circleci.com/gh/lucas-clemente/quic-go)
[![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/)
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
## Roadmap
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
## Guides
We currently support Go 1.9+.
Installing and updating dependencies:
go get -t -u ./...
Running tests:
go test ./...
### Running the example server
go run example/main.go -www /var/www/
Using the `quic_client` from chromium:
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
Using Chrome:
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
### QUIC without HTTP/2
Take a look at [this echo example](example/echo/echo.go).
### Using the example client
go run example/client/main.go https://clemente.io
## Usage
### As a server
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
```go
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
```
### As a client
See the [example client](example/client/main.go). Use a `h2quic.RoundTripper` as a `Transport` in a `http.Client`.
```go
http.Client{
Transport: &h2quic.RoundTripper{},
}
```
## Contributing
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/lucas-clemente/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.

View File

@ -0,0 +1,35 @@
version: "{build}"
os: Windows Server 2012 R2
environment:
GOPATH: c:\gopath
CGO_ENABLED: 0
TIMESCALE_FACTOR: 20
matrix:
- GOARCH: 386
- GOARCH: amd64
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install:
- rmdir c:\go /s /q
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.11.windows-amd64.zip
- 7z x go1.11.windows-amd64.zip -y -oC:\ > NUL
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
- echo %PATH%
- echo %GOPATH%
- git submodule update --init --recursive
- go get github.com/onsi/ginkgo/ginkgo
- go get github.com/onsi/gomega
- go version
- go env
- go get -v -t ./...
build_script:
- ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage benchmark,integrationtests
- ginkgo -randomizeAllSpecs -randomizeSuites -trace benchmark -- -samples=1
test: off
deploy: off

View File

@ -0,0 +1,26 @@
package benchmark
import (
"flag"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestBenchmark(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Benchmark Suite")
}
var (
size int // file size in MB, will be read from flags
samples int // number of samples for Measure, will be read from flags
)
func init() {
flag.IntVar(&size, "size", 50, "data length (in MB)")
flag.IntVar(&samples, "samples", 6, "number of samples")
flag.Parse()
}

View File

@ -0,0 +1,87 @@
package benchmark
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"math/rand"
"net"
quic "github.com/lucas-clemente/quic-go"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func init() {
var _ = Describe("Benchmarks", func() {
dataLen := size * /* MB */ 1e6
data := make([]byte, dataLen)
rand.Seed(GinkgoRandomSeed())
rand.Read(data) // no need to check for an error. math.Rand.Read never errors
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with version %s", version), func() {
Measure(fmt.Sprintf("transferring a %d MB file", size), func(b Benchmarker) {
var ln quic.Listener
serverAddr := make(chan net.Addr)
handshakeChan := make(chan struct{})
// start the server
go func() {
defer GinkgoRecover()
var err error
ln, err = quic.ListenAddr(
"localhost:0",
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
serverAddr <- ln.Addr()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
// wait for the client to complete the handshake before sending the data
// this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets
<-handshakeChan
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(data)
Expect(err).ToNot(HaveOccurred())
err = str.Close()
Expect(err).ToNot(HaveOccurred())
}()
// start the client
addr := <-serverAddr
sess, err := quic.DialAddr(
addr.String(),
&tls.Config{InsecureSkipVerify: true},
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
close(handshakeChan)
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
buf := &bytes.Buffer{}
// measure the time it takes to download the dataLen bytes
// note we're measuring the time for the transfer, i.e. excluding the handshake
runtime := b.Time("transfer time", func() {
_, err := io.Copy(buf, str)
Expect(err).NotTo(HaveOccurred())
})
Expect(buf.Bytes()).To(Equal(data))
b.RecordValue("transfer rate [MB/s]", float64(dataLen)/1e6/runtime.Seconds())
ln.Close()
sess.Close()
}, samples)
})
}
})
}

View File

@ -0,0 +1,27 @@
package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var bufferPool sync.Pool
func getPacketBuffer() *[]byte {
return bufferPool.Get().(*[]byte)
}
func putPacketBuffer(buf *[]byte) {
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!")
}
bufferPool.Put(buf)
}
func init() {
bufferPool.New = func() interface{} {
b := make([]byte, 0, protocol.MaxReceivePacketSize)
return &b
}
}

View File

@ -0,0 +1,21 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Buffer Pool", func() {
It("returns buffers of cap", func() {
buf := *getPacketBuffer()
Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize)))
})
It("panics if wrong-sized buffers are passed", func() {
Expect(func() {
putPacketBuffer(&[]byte{0})
}).To(Panic())
})
})

595
vendor/lucas-clemente/quic-go/client.go vendored Normal file
View File

@ -0,0 +1,595 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type client struct {
mutex sync.Mutex
conn connection
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
hostname string
packetHandlers packetHandlerManager
token []byte
numRetries int
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
mintConf *mint.Config
config *Config
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialVersion protocol.VersionNumber
version protocol.VersionNumber
handshakeChan chan struct{}
closeCallback func(protocol.ConnectionID)
session quicSession
logger utils.Logger
}
var _ packetHandler = &client{}
var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
)
// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return DialAddrContext(context.Background(), addr, tlsConf, config)
}
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// The hostname for SNI is taken from the given address.
func DialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// The host parameter is used for SNI.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
}
func dialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
if !createdPacketConn {
for _, v := range config.Versions {
if v == protocol.Version44 {
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
}
}
}
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
if err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.session, nil
}
func newClient(
pconn net.PacketConn,
remoteAddr net.Addr,
config *Config,
tlsConf *tls.Config,
host string,
closeCallback func(protocol.ConnectionID),
createdPacketConn bool,
) (*client, error) {
var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
var err error
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
onClose := func(protocol.ConnectionID) {}
if closeCallback != nil {
onClose = closeCallback
}
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
hostname: hostname,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, c.generateConnectionIDs()
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
idleTimeout := protocol.DefaultIdleTimeout
if config.IdleTimeout != 0 {
idleTimeout = config.IdleTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
}
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 && !createdPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
for _, v := range versions {
if v == protocol.Version44 {
connIDLen = 0
}
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
ConnectionIDLength: connIDLen,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
KeepAlive: config.KeepAlive,
}
}
func (c *client) generateConnectionIDs() error {
connIDLen := protocol.ConnectionIDLenGQUIC
if c.version.UsesTLS() {
connIDLen = c.config.ConnectionIDLength
}
srcConnID, err := generateConnectionID(connIDLen)
if err != nil {
return err
}
destConnID := srcConnID
if c.version.UsesTLS() {
destConnID, err = generateConnectionIDForInitial()
if err != nil {
return err
}
}
c.srcConnID = srcConnID
c.destConnID = destConnID
if c.version == protocol.Version44 {
c.srcConnID = nil
}
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
var err error
if c.version.UsesTLS() {
err = c.dialTLS(ctx)
} else {
err = c.dialGQUIC(ctx)
}
return err
}
func (c *client) dialGQUIC(ctx context.Context) error {
if err := c.createNewGQUICSession(); err != nil {
return err
}
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return err
}
func (c *client) dialTLS(ctx context.Context) error {
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
DisableMigration: true,
}
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ExtensionHandler = extHandler
mintConf.ServerName = c.hostname
c.mintConf = mintConf
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
err = c.establishSecureConnection(ctx)
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return err
}
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection(ctx context.Context) error {
errorChan := make(chan error, 1)
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
c.conn.Close()
}
errorChan <- err
}()
select {
case <-ctx.Done():
// The session will send a PeerGoingAway error to the server.
c.session.Close()
return ctx.Err()
case err := <-errorChan:
return err
case <-c.handshakeChan:
// handshake successfully completed
return nil
}
}
func (c *client) handlePacket(p *receivedPacket) {
if err := c.handlePacketImpl(p); err != nil {
c.logger.Errorf("error handling packet: %s", err)
}
}
func (c *client) handlePacketImpl(p *receivedPacket) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// handle Version Negotiation Packets
if p.header.IsVersionNegotiation {
err := c.handleVersionNegotiationPacket(p.header)
if err != nil {
c.session.destroy(err)
}
// version negotiation packets have no payload
return err
}
if !c.version.UsesIETFHeaderFormat() {
connID := p.header.DestConnectionID
// reject packets with truncated connection id if we didn't request truncation
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
// reject packets with the wrong connection ID
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
}
if p.header.ResetFlag {
return c.handlePublicReset(p)
}
} else {
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
}
if p.header.IsLongHeader {
switch p.header.Type {
case protocol.PacketTypeRetry:
c.handleRetryPacket(p.header)
return nil
case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
default:
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
}
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
}
c.session.handlePacket(p)
return nil
}
func (c *client) handlePublicReset(p *receivedPacket) error {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
return errors.New("Received a spoofed Public Reset")
}
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
if err != nil {
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
}
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
return nil
}
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
return nil
}
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
// this might be a packet sent by an attacker (or by a terribly broken server implementation)
// ignore it
return nil
}
}
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
if err := c.generateConnectionIDs(); err != nil {
return err
}
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
return nil
}
func (c *client) handleRetryPacket(hdr *wire.Header) {
c.logger.Debugf("<- Received Retry")
hdr.Log(c.logger)
// A server that performs multiple retries must use a source connection ID of at least 8 bytes.
// Only a server that won't send additional Retries can use shorter connection IDs.
if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial)
return
}
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
return
}
c.numRetries++
if c.numRetries > protocol.MaxRetries {
c.session.destroy(qerr.CryptoTooManyRejects)
return
}
c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token
c.session.destroy(errCloseSessionForRetry)
}
func (c *client) createNewGQUICSession() error {
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newClientSession(
c.conn,
runner,
c.hostname,
c.version,
c.destConnID,
c.srcConnID,
c.tlsConf,
c.config,
c.initialVersion,
c.negotiatedVersions,
c.logger,
)
if err != nil {
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
if c.config.RequestConnectionIDOmission {
c.packetHandlers.Add(protocol.ConnectionID{}, c)
}
return nil
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) error {
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newTLSClientSession(
c.conn,
runner,
c.token,
c.destConnID,
c.srcConnID,
c.config,
c.mintConf,
paramsChan,
1,
c.logger,
c.version,
)
if err != nil {
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
return nil
}
func (c *client) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return nil
}
return c.session.Close()
}
func (c *client) destroy(e error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return
}
c.session.destroy(e)
}
func (c *client) GetVersion() protocol.VersionNumber {
c.mutex.Lock()
v := c.version
c.mutex.Unlock()
return v
}
func (c *client) GetPerspective() protocol.Perspective {
return protocol.PerspectiveClient
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,18 @@
coverage:
round: nearest
ignore:
- streams_map_incoming_bidi.go
- streams_map_incoming_uni.go
- streams_map_outgoing_bidi.go
- streams_map_outgoing_uni.go
- h2quic/gzipreader.go
- h2quic/response.go
- internal/ackhandler/packet_linkedlist.go
- internal/utils/byteinterval_linkedlist.go
- internal/utils/packetinterval_linkedlist.go
- internal/utils/linkedlist/linkedlist.go
status:
project:
default:
threshold: 0.5
patch: false

54
vendor/lucas-clemente/quic-go/conn.go vendored Normal file
View File

@ -0,0 +1,54 @@
package quic
import (
"net"
"sync"
)
type connection interface {
Write([]byte) error
Read([]byte) (int, net.Addr, error)
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
SetCurrentRemoteAddr(net.Addr)
}
type conn struct {
mutex sync.RWMutex
pconn net.PacketConn
currentAddr net.Addr
}
var _ connection = &conn{}
func (c *conn) Write(p []byte) error {
_, err := c.pconn.WriteTo(p, c.currentAddr)
return err
}
func (c *conn) Read(p []byte) (int, net.Addr, error) {
return c.pconn.ReadFrom(p)
}
func (c *conn) SetCurrentRemoteAddr(addr net.Addr) {
c.mutex.Lock()
c.currentAddr = addr
c.mutex.Unlock()
}
func (c *conn) LocalAddr() net.Addr {
return c.pconn.LocalAddr()
}
func (c *conn) RemoteAddr() net.Addr {
c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
}
func (c *conn) Close() error {
return c.pconn.Close()
}

View File

@ -0,0 +1,119 @@
package quic
import (
"bytes"
"errors"
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockPacketConn struct {
addr net.Addr
dataToRead chan []byte
dataReadFrom net.Addr
readErr error
dataWritten bytes.Buffer
dataWrittenTo net.Addr
closed bool
}
func newMockPacketConn() *mockPacketConn {
return &mockPacketConn{
dataToRead: make(chan []byte, 1000),
}
}
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
if c.readErr != nil {
return 0, nil, c.readErr
}
data, ok := <-c.dataToRead
if !ok {
return 0, nil, errors.New("connection closed")
}
n := copy(b, data)
return n, c.dataReadFrom, nil
}
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
c.dataWrittenTo = addr
return c.dataWritten.Write(b)
}
func (c *mockPacketConn) Close() error {
if !c.closed {
close(c.dataToRead)
}
c.closed = true
return nil
}
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }
var _ net.PacketConn = &mockPacketConn{}
var _ = Describe("Connection", func() {
var c *conn
var packetConn *mockPacketConn
BeforeEach(func() {
addr := &net.UDPAddr{
IP: net.IPv4(192, 168, 100, 200),
Port: 1337,
}
packetConn = newMockPacketConn()
c = &conn{
currentAddr: addr,
pconn: packetConn,
}
})
It("writes", func() {
err := c.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(packetConn.dataWritten.Bytes()).To(Equal([]byte("foobar")))
Expect(packetConn.dataWrittenTo.String()).To(Equal("192.168.100.200:1337"))
})
It("reads", func() {
packetConn.dataToRead <- []byte("foo")
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
p := make([]byte, 10)
n, raddr, err := c.Read(p)
Expect(err).ToNot(HaveOccurred())
Expect(raddr.String()).To(Equal("127.0.0.1:1336"))
Expect(n).To(Equal(3))
Expect(p[0:3]).To(Equal([]byte("foo")))
})
It("gets the remote address", func() {
Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337"))
})
It("gets the local address", func() {
addr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 1),
Port: 1234,
}
packetConn.addr = addr
Expect(c.LocalAddr()).To(Equal(addr))
})
It("changes the remote address", func() {
addr := &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 7331,
}
c.SetCurrentRemoteAddr(addr)
Expect(c.RemoteAddr().String()).To(Equal(addr.String()))
})
It("closes", func() {
err := c.Close()
Expect(err).ToNot(HaveOccurred())
Expect(packetConn.closed).To(BeTrue())
})
})

View File

@ -0,0 +1,41 @@
package quic
import (
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStream interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
handleStreamFrame(*wire.StreamFrame) error
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStreamImpl struct {
*stream
}
var _ cryptoStream = &cryptoStreamImpl{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStreamImpl{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPos = offset
}

View File

@ -0,0 +1,26 @@
package quic
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Crypto Stream", func() {
var (
str *cryptoStreamImpl
mockSender *MockStreamSender
)
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStreamImpl)
})
It("sets the read offset", func() {
str.setReadOffset(0x42)
Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42)))
Expect(str.receiveStream.frameQueue.readPos).To(Equal(protocol.ByteCount(0x42)))
})
})

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

View File

@ -0,0 +1,65 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg width="601px" height="242px" viewBox="0 0 601 242" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<!-- Generator: Sketch 3.7.2 (28276) - http://www.bohemiancoding.com/sketch -->
<title>quic</title>
<desc>Created with Sketch.</desc>
<defs></defs>
<g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="Group-4-Copy" transform="translate(586.087964, 15.755663) rotate(-26.000000) translate(-586.087964, -15.755663) translate(573.036290, 1.253803)">
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
</g>
<g id="Group-3-Copy" transform="translate(499.346306, 15.541651) rotate(25.000000) translate(-499.346306, -15.541651) translate(486.346306, 0.541651)">
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
</g>
<g id="Group-3" transform="translate(1.896652, 21.000000)">
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
</g>
<g id="Group-4" transform="translate(117.000000, 22.000000)">
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
</g>
<g id="Group-2" transform="translate(26.000000, 40.000000)">
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
</g>
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
</g>
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
</g>
</g>
<g id="Group-2-Copy" transform="translate(544.820291, 73.354278) scale(-1, 1) translate(-544.820291, -73.354278) translate(498.000000, 40.000000)">
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
</g>
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
</g>
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
</g>
</g>
<text id="QUIC" font-family="SourceCodePro-Light, Source Code Pro" font-size="288" font-weight="300" letter-spacing="-20" fill="#000000">
<tspan x="-13.6" y="197">QUI</tspan>
<tspan x="444.8" y="197">C</tspan>
</text>
</g>
</svg>

After

Width:  |  Height:  |  Size: 11 KiB

View File

@ -0,0 +1,9 @@
FROM scratch
VOLUME /certs
VOLUME /www
EXPOSE 6121
ADD main /main
CMD ["/main", "-bind=0.0.0.0", "-certpath=/certs/", "-www=/www"]

View File

@ -0,0 +1,7 @@
# About the certificate
Yes, this folder contains a private key and a certificate for quic.clemente.io.
Unfortunately we need a valid certificate for the integration tests with Chrome and `quic_client`. No important data is served on the "real" `quic.clemente.io` (only a test page), and the MITM problem is imho negligible.
If you figure out a way to test with Chrome without having a cert and key here, let us now in an issue.

View File

@ -0,0 +1,66 @@
package main
import (
"bytes"
"flag"
"io"
"net/http"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
func main() {
verbose := flag.Bool("v", false, "verbose")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
urls := flag.Args()
logger := utils.DefaultLogger
if *verbose {
logger.SetLogLevel(utils.LogLevelDebug)
} else {
logger.SetLogLevel(utils.LogLevelInfo)
}
logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
roundTripper := &h2quic.RoundTripper{
QuicConfig: &quic.Config{Versions: versions},
}
defer roundTripper.Close()
hclient := &http.Client{
Transport: roundTripper,
}
var wg sync.WaitGroup
wg.Add(len(urls))
for _, addr := range urls {
logger.Infof("GET %s", addr)
go func(addr string) {
rsp, err := hclient.Get(addr)
if err != nil {
panic(err)
}
logger.Infof("Got response for %s: %#v", addr, rsp)
body := &bytes.Buffer{}
_, err = io.Copy(body, rsp.Body)
if err != nil {
panic(err)
}
logger.Infof("Request Body:")
logger.Infof("%s", body.Bytes())
wg.Done()
}(addr)
}
wg.Wait()
}

View File

@ -0,0 +1,105 @@
package main
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"log"
"math/big"
quic "github.com/lucas-clemente/quic-go"
)
const addr = "localhost:4242"
const message = "foobar"
// We start a server echoing data on the first stream the client opens,
// then connect with a client, send the message, and wait for its receipt.
func main() {
go func() { log.Fatal(echoServer()) }()
err := clientMain()
if err != nil {
panic(err)
}
}
// Start a server that echos all data on the first stream opened by the client
func echoServer() error {
listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
if err != nil {
return err
}
sess, err := listener.Accept()
if err != nil {
return err
}
stream, err := sess.AcceptStream()
if err != nil {
panic(err)
}
// Echo through the loggingWriter
_, err = io.Copy(loggingWriter{stream}, stream)
return err
}
func clientMain() error {
session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
if err != nil {
return err
}
stream, err := session.OpenStreamSync()
if err != nil {
return err
}
fmt.Printf("Client: Sending '%s'\n", message)
_, err = stream.Write([]byte(message))
if err != nil {
return err
}
buf := make([]byte, len(message))
_, err = io.ReadFull(stream, buf)
if err != nil {
return err
}
fmt.Printf("Client: Got '%s'\n", buf)
return nil
}
// A wrapper for io.Writer that also logs the message.
type loggingWriter struct{ io.Writer }
func (w loggingWriter) Write(b []byte) (int, error) {
fmt.Printf("Server: Got '%s'\n", string(b))
return w.Writer.Write(b)
}
// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
}

View File

@ -0,0 +1,62 @@
-----BEGIN CERTIFICATE-----
MIIGCzCCBPOgAwIBAgISA6pwet1vq9IjS9ThcggZYGfyMA0GCSqGSIb3DQEBCwUA
MEoxCzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MSMwIQYDVQQD
ExpMZXQncyBFbmNyeXB0IEF1dGhvcml0eSBYMzAeFw0xODA2MDkwODI2MjVaFw0x
ODA5MDcwODI2MjVaMBsxGTAXBgNVBAMTEHF1aWMuY2xlbWVudGUuaW8wggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDqON9RKCMK4dHxLTLiv3n+b/v4KCbY
Fo38rxCcv4/qoWI1Zoz4XMyPnTZPrzVMz6rhDKWTve1v7g+itgQDp1OoRo5ihtqC
m9Dr1Ed+qb3hHtFRBBOTiWQy1Y0fvUDYGlSHQ4R7xXYVzCYkQ9zopqr2otJp/9ZA
1Yy3ATITSgcds9wJaAOpkSxCx+D7cqpGbaNtxogRZ4ZT1vVs3l7Kr26fvnALMM4K
nf2Rpq1ZjgPRoLNcfUWCZO7gr+VNCFaI2AQRaAzZkSXAk0B91waTtaFD87kYXl9Z
8Jgx1PhJnO8WBUz/7+OAO6M+kuT78e6FltMN004Q8WmuNww6PzmVwnmpAgMBAAGj
ggMYMIIDFDAOBgNVHQ8BAf8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMAwGA1UdEwEB/wQCMAAwHQYDVR0OBBYEFIs3yJaLnEw7MRtNf9EbzESy
TrHwMB8GA1UdIwQYMBaAFKhKamMEfd265tE5t6ZFZe/zqOyhMG8GCCsGAQUFBwEB
BGMwYTAuBggrBgEFBQcwAYYiaHR0cDovL29jc3AuaW50LXgzLmxldHNlbmNyeXB0
Lm9yZzAvBggrBgEFBQcwAoYjaHR0cDovL2NlcnQuaW50LXgzLmxldHNlbmNyeXB0
Lm9yZy8wGwYDVR0RBBQwEoIQcXVpYy5jbGVtZW50ZS5pbzCB/gYDVR0gBIH2MIHz
MAgGBmeBDAECATCB5gYLKwYBBAGC3xMBAQEwgdYwJgYIKwYBBQUHAgEWGmh0dHA6
Ly9jcHMubGV0c2VuY3J5cHQub3JnMIGrBggrBgEFBQcCAjCBngyBm1RoaXMgQ2Vy
dGlmaWNhdGUgbWF5IG9ubHkgYmUgcmVsaWVkIHVwb24gYnkgUmVseWluZyBQYXJ0
aWVzIGFuZCBvbmx5IGluIGFjY29yZGFuY2Ugd2l0aCB0aGUgQ2VydGlmaWNhdGUg
UG9saWN5IGZvdW5kIGF0IGh0dHBzOi8vbGV0c2VuY3J5cHQub3JnL3JlcG9zaXRv
cnkvMIIBBAYKKwYBBAHWeQIEAgSB9QSB8gDwAHYA23Sv7ssp7LH+yj5xbSzluaq7
NveEcYPHXZ1PN7Yfv2QAAAFj495INQAABAMARzBFAiEApF0BFCWyGIUrJrsYuugt
tshGVdg2+7f6d4B1D/xF2s0CIGGsVL2nRlXJXrkk3aa83lH4HzP9vcQSnMFHdXOf
9XeZAHYAKTxRllTIOWW6qlD8WAfUt2+/WHopctykwwz05UVH9HgAAAFj495IQwAA
BAMARzBFAiEAkV4gbM6hucL7ZwqTzb5fKxYhk6WHr5y8pzClZD3qqv4CIBYH7MSA
P05CXv7tHiHEizIvhWJJvVa2E6XLjbDRQMnUMA0GCSqGSIb3DQEBCwUAA4IBAQA2
CALjtlxGXkkfKsRgKDoPpzg/IAl7crq5OrGGwW/bxbeDeiRHVt0Hlhr+0XPYlh/A
m8qBlg7TMHJa2zIt7wkG/MX3d9bO2bxXxsdjfMXLtxQu7eZ5nzGuXL8WGnZwtSqz
M5pOF/AU1JQojIhehKCeqqUi5UobxXUm+9D6OVmr8s732X3n/TL6pgsWRhay9tjB
kdZ9TSe0tLYyXnbwHp0rgwNWOMMN1Tc+Fpqc8UlrCq5REb0bLIQ6A2IJH08MEPWG
ukXHPAaHLz9oB1emR3flCoMKB0KrXUpDFXemOIPN6QVBO16LNNuRffKwXnzp60+b
MoB4Krxrab7TzlfT+HnP
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIEkjCCA3qgAwIBAgIQCgFBQgAAAVOFc2oLheynCDANBgkqhkiG9w0BAQsFADA/
MSQwIgYDVQQKExtEaWdpdGFsIFNpZ25hdHVyZSBUcnVzdCBDby4xFzAVBgNVBAMT
DkRTVCBSb290IENBIFgzMB4XDTE2MDMxNzE2NDA0NloXDTIxMDMxNzE2NDA0Nlow
SjELMAkGA1UEBhMCVVMxFjAUBgNVBAoTDUxldCdzIEVuY3J5cHQxIzAhBgNVBAMT
GkxldCdzIEVuY3J5cHQgQXV0aG9yaXR5IFgzMIIBIjANBgkqhkiG9w0BAQEFAAOC
AQ8AMIIBCgKCAQEAnNMM8FrlLke3cl03g7NoYzDq1zUmGSXhvb418XCSL7e4S0EF
q6meNQhY7LEqxGiHC6PjdeTm86dicbp5gWAf15Gan/PQeGdxyGkOlZHP/uaZ6WA8
SMx+yk13EiSdRxta67nsHjcAHJyse6cF6s5K671B5TaYucv9bTyWaN8jKkKQDIZ0
Z8h/pZq4UmEUEz9l6YKHy9v6Dlb2honzhT+Xhq+w3Brvaw2VFn3EK6BlspkENnWA
a6xK8xuQSXgvopZPKiAlKQTGdMDQMc2PMTiVFrqoM7hD8bEfwzB/onkxEz0tNvjj
/PIzark5McWvxI0NHWQWM6r6hCm21AvA2H3DkwIDAQABo4IBfTCCAXkwEgYDVR0T
AQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAYYwfwYIKwYBBQUHAQEEczBxMDIG
CCsGAQUFBzABhiZodHRwOi8vaXNyZy50cnVzdGlkLm9jc3AuaWRlbnRydXN0LmNv
bTA7BggrBgEFBQcwAoYvaHR0cDovL2FwcHMuaWRlbnRydXN0LmNvbS9yb290cy9k
c3Ryb290Y2F4My5wN2MwHwYDVR0jBBgwFoAUxKexpHsscfrb4UuQdf/EFWCFiRAw
VAYDVR0gBE0wSzAIBgZngQwBAgEwPwYLKwYBBAGC3xMBAQEwMDAuBggrBgEFBQcC
ARYiaHR0cDovL2Nwcy5yb290LXgxLmxldHNlbmNyeXB0Lm9yZzA8BgNVHR8ENTAz
MDGgL6AthitodHRwOi8vY3JsLmlkZW50cnVzdC5jb20vRFNUUk9PVENBWDNDUkwu
Y3JsMB0GA1UdDgQWBBSoSmpjBH3duubRObemRWXv86jsoTANBgkqhkiG9w0BAQsF
AAOCAQEA3TPXEfNjWDjdGBX7CVW+dla5cEilaUcne8IkCJLxWh9KEik3JHRRHGJo
uM2VcGfl96S8TihRzZvoroed6ti6WqEBmtzw3Wodatg+VyOeph4EYpr/1wXKtx8/
wApIvJSwtmVi4MFU5aMqrSDE6ea73Mj2tcMyo5jMd6jmeWUHK8so/joWUoHOUgwu
X4Po1QYz+3dszkDqMp4fklxBwXRsW10KXzPMTZ+sOPAveyxindmjkW8lGy+QsRlG
PfZ+G6Z6h7mjem0Y+iWlkYcV4PIWL1iwBi8saCbGS5jN2p8M+X+Q7UNKEkROb3N6
KOqkqm57TH2H3eDJAkSnh6/DNFu0Qg==
-----END CERTIFICATE-----

View File

@ -0,0 +1,174 @@
package main
import (
"crypto/md5"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"mime/multipart"
"net/http"
"path"
"runtime"
"strings"
"sync"
_ "net/http/pprof"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type binds []string
func (b binds) String() string {
return strings.Join(b, ",")
}
func (b *binds) Set(v string) error {
*b = strings.Split(v, ",")
return nil
}
// Size is needed by the /demo/upload handler to determine the size of the uploaded file
type Size interface {
Size() int64
}
func init() {
http.HandleFunc("/demo/tile", func(w http.ResponseWriter, r *http.Request) {
// Small 40x40 png
w.Write([]byte{
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d,
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x28,
0x01, 0x03, 0x00, 0x00, 0x00, 0xb6, 0x30, 0x2a, 0x2e, 0x00, 0x00, 0x00,
0x03, 0x50, 0x4c, 0x54, 0x45, 0x5a, 0xc3, 0x5a, 0xad, 0x38, 0xaa, 0xdb,
0x00, 0x00, 0x00, 0x0b, 0x49, 0x44, 0x41, 0x54, 0x78, 0x01, 0x63, 0x18,
0x61, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x01, 0xe2, 0xb8, 0x75, 0x22, 0x00,
0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82,
})
})
http.HandleFunc("/demo/tiles", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "<html><head><style>img{width:40px;height:40px;}</style></head><body>")
for i := 0; i < 200; i++ {
fmt.Fprintf(w, `<img src="/demo/tile?cachebust=%d">`, i)
}
io.WriteString(w, "</body></html>")
})
http.HandleFunc("/demo/echo", func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
fmt.Printf("error reading body while handling /echo: %s\n", err.Error())
}
w.Write(body)
})
// accept file uploads and return the MD5 of the uploaded file
// maximum accepted file size is 1 GB
http.HandleFunc("/demo/upload", func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
err := r.ParseMultipartForm(1 << 30) // 1 GB
if err == nil {
var file multipart.File
file, _, err = r.FormFile("uploadfile")
if err == nil {
var size int64
if sizeInterface, ok := file.(Size); ok {
size = sizeInterface.Size()
b := make([]byte, size)
file.Read(b)
md5 := md5.Sum(b)
fmt.Fprintf(w, "%x", md5)
return
}
err = errors.New("couldn't get uploaded file size")
}
}
if err != nil {
utils.DefaultLogger.Infof("Error receiving upload: %#v", err)
}
}
io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data">
<input type="file" name="uploadfile"><br>
<input type="submit">
</form></body></html>`)
})
}
func getBuildDir() string {
_, filename, _, ok := runtime.Caller(0)
if !ok {
panic("Failed to get current frame")
}
return path.Dir(filename)
}
func main() {
// defer profile.Start().Stop()
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
// runtime.SetBlockProfileRate(1)
verbose := flag.Bool("v", false, "verbose")
bs := binds{}
flag.Var(&bs, "bind", "bind to")
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
www := flag.String("www", "/var/www", "www data")
tcp := flag.Bool("tcp", false, "also listen on TCP")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
logger := utils.DefaultLogger
if *verbose {
logger.SetLogLevel(utils.LogLevelDebug)
} else {
logger.SetLogLevel(utils.LogLevelInfo)
}
logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
certFile := *certPath + "/fullchain.pem"
keyFile := *certPath + "/privkey.pem"
http.Handle("/", http.FileServer(http.Dir(*www)))
if len(bs) == 0 {
bs = binds{"localhost:6121"}
}
var wg sync.WaitGroup
wg.Add(len(bs))
for _, b := range bs {
bCap := b
go func() {
var err error
if *tcp {
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
} else {
server := h2quic.Server{
Server: &http.Server{Addr: bCap},
QuicConfig: &quic.Config{Versions: versions},
}
err = server.ListenAndServeTLS(certFile, keyFile)
}
if err != nil {
fmt.Println(err)
}
wg.Done()
}()
}
wg.Wait()
}

View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDqON9RKCMK4dHx
LTLiv3n+b/v4KCbYFo38rxCcv4/qoWI1Zoz4XMyPnTZPrzVMz6rhDKWTve1v7g+i
tgQDp1OoRo5ihtqCm9Dr1Ed+qb3hHtFRBBOTiWQy1Y0fvUDYGlSHQ4R7xXYVzCYk
Q9zopqr2otJp/9ZA1Yy3ATITSgcds9wJaAOpkSxCx+D7cqpGbaNtxogRZ4ZT1vVs
3l7Kr26fvnALMM4Knf2Rpq1ZjgPRoLNcfUWCZO7gr+VNCFaI2AQRaAzZkSXAk0B9
1waTtaFD87kYXl9Z8Jgx1PhJnO8WBUz/7+OAO6M+kuT78e6FltMN004Q8WmuNww6
PzmVwnmpAgMBAAECggEAAW7hpux48msZTsF5CzwisfTbdNRCEJZqvf4QOvVNGyFr
qWn8ONTQh5xtpaUrzVGD+SaLqNDDsCijvdohQih28ZOk8WNj2OK9L4Q3/8VoHQWE
QFunBwMTMuBtoaEV0XyvwbgfCmbV5yI9pYEoy9+hMisi4HUpSXJFDyWZudZ9Hqhl
FBf5UPxF1AjZViODVfnKkrWh85jaENyWjWrMEDSohzxP8IStFMK8E0feXaQb+G0f
uTelpG2JmVcO3xaYgtRVeS4p5liMQg5gE5oa2Jh7Vp3TwMhQVg0aLmslSLAYlPoh
hyBeOS3ucyFHoC/6Stnnx3jdOEf2lEUObJj3QVEeBQKBgQD3qptNY9R62UQFO8gT
pseEO5CAZRGuKG0VLPNqKKerYQreiT3xYOTPmwEy+xXzYV6DKHtlKDetHuOB862M
E1bKmjDedafQ3KMc4tywLW+U9I8GyooT8uoetzARzgoftRcMzdeu4fexWTsi+kOh
5/PeXUBnFph8E68nXHQR5+MWAwKBgQDyGnT9KUVuLvzrFAghe9Eu1iZJSDs/IrJj
we3XQ7loqMc/Qv34eVKATsgtb2cTSeTivQcSvQO+Uu9dgo17BoWAABDTaV8NAOZ6
cV2kedrWnxaTRXzB6Z5EKLg+DOMTpCV4Nf3OwzGq/mnKAe8cNM/hS+8HcZywKwr2
UwLKSq6n4wKBgQDXGUaOnVCSbZZlETnAz43i67ShvqXvY07yIDs8jRiqgLrm8b1p
oaS4JkCRXX8ABSYHtaYOAjLw2a3wVIn66WTsy6P74aWhga7szJ+tJ5kMfqal2Ey5
7LSnfqRyIkeqqCXfyfsz+S+dyQjSZRdOS90C2Gyx2+8NfC8YeXSZhJM2rwKBgQDB
pL28G/2nsrejQ2N5fLKE9s6qwLZ6ukLbHasiCc5L0uuDQw8mZcvCSsE77iYQvILx
hGYa68oJugYw0hJdu4qeJe9PWbGoEfdHKlPPEZQjJB4Hb4XpB/YJ6FPtdZtPA3Tg
4LaAYYnhjhqJc+CPvAIl3vlyB8Je+h6LhTvvF6r5JwKBgHBkpQ2Y+N3dMgfE8T38
+Vy+75i9ew3F0TB8o2MWJzYLYCzNLqrB4GkZieudoYP1Cmh50wYIifsk2zNpxOG0
27tUfMxza4ITUfmtVHS22DGWCqb1TgfQt7jajdioK4TUOZFEmiQiOYMZoFMenbzZ
miz6s3fWx1/b3wNd1J5uV2QS
-----END PRIVATE KEY-----

View File

@ -0,0 +1,158 @@
package quic
import (
"errors"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type frameSorter struct {
queue map[protocol.ByteCount][]byte
readPos protocol.ByteCount
finalOffset protocol.ByteCount
gaps *utils.ByteIntervalList
}
var errDuplicateStreamData = errors.New("Duplicate Stream Data")
func newFrameSorter() *frameSorter {
s := frameSorter{
gaps: utils.NewByteIntervalList(),
queue: make(map[protocol.ByteCount][]byte),
finalOffset: protocol.MaxByteCount,
}
s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount})
return &s
}
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, fin bool) error {
err := s.push(data, offset, fin)
if err == errDuplicateStreamData {
return nil
}
return err
}
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, fin bool) error {
if fin {
s.finalOffset = offset + protocol.ByteCount(len(data))
}
if len(data) == 0 {
return nil
}
var wasCut bool
if oldData, ok := s.queue[offset]; ok {
if len(data) <= len(oldData) {
return errDuplicateStreamData
}
data = data[len(oldData):]
offset += protocol.ByteCount(len(oldData))
wasCut = true
}
start := offset
end := offset + protocol.ByteCount(len(data))
// skip all gaps that are before this stream frame
var gap *utils.ByteIntervalElement
for gap = s.gaps.Front(); gap != nil; gap = gap.Next() {
// the frame is a duplicate. Ignore it
if end <= gap.Value.Start {
return errDuplicateStreamData
}
if end > gap.Value.Start && start <= gap.Value.End {
break
}
}
if gap == nil {
return errors.New("StreamFrameSorter BUG: no gap found")
}
if start < gap.Value.Start {
add := gap.Value.Start - start
offset += add
start += add
data = data[add:]
wasCut = true
}
// find the highest gaps whose Start lies before the end of the frame
endGap := gap
for end >= endGap.Value.End {
nextEndGap := endGap.Next()
if nextEndGap == nil {
return errors.New("StreamFrameSorter BUG: no end gap found")
}
if endGap != gap {
s.gaps.Remove(endGap)
}
if end <= nextEndGap.Value.Start {
break
}
// delete queued frames completely covered by the current frame
delete(s.queue, endGap.Value.End)
endGap = nextEndGap
}
if end > endGap.Value.End {
cutLen := end - endGap.Value.End
len := protocol.ByteCount(len(data)) - cutLen
end -= cutLen
data = data[:len]
wasCut = true
}
if start == gap.Value.Start {
if end >= gap.Value.End {
// the frame completely fills this gap
// delete the gap
s.gaps.Remove(gap)
}
if end < endGap.Value.End {
// the frame covers the beginning of the gap
// adjust the Start value to shrink the gap
endGap.Value.Start = end
}
} else if end == endGap.Value.End {
// the frame covers the end of the gap
// adjust the End value to shrink the gap
gap.Value.End = start
} else {
if gap == endGap {
// the frame lies within the current gap, splitting it into two
// insert a new gap and adjust the current one
intv := utils.ByteInterval{Start: end, End: gap.Value.End}
s.gaps.InsertAfter(intv, gap)
gap.Value.End = start
} else {
gap.Value.End = start
endGap.Value.Start = end
}
}
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
return errors.New("Too many gaps in received data")
}
if wasCut {
newData := make([]byte, len(data))
copy(newData, data)
data = newData
}
s.queue[offset] = data
return nil
}
func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
data, ok := s.queue[s.readPos]
if !ok {
return nil, s.readPos >= s.finalOffset
}
delete(s.queue, s.readPos)
s.readPos += protocol.ByteCount(len(data))
return data, s.readPos >= s.finalOffset
}

View File

@ -0,0 +1,435 @@
package quic
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("STREAM frame sorter", func() {
var s *frameSorter
checkGaps := func(expectedGaps []utils.ByteInterval) {
Expect(s.gaps.Len()).To(Equal(len(expectedGaps)))
var i int
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
Expect(gap.Value).To(Equal(expectedGaps[i]))
i++
}
}
BeforeEach(func() {
s = newFrameSorter()
})
It("head returns nil when empty", func() {
Expect(s.Pop()).To(BeNil())
})
Context("Push", func() {
It("inserts and pops a single frame", func() {
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(Equal([]byte("foobar")))
Expect(fin).To(BeFalse())
Expect(s.Pop()).To(BeNil())
})
It("inserts and pops two consecutive frame", func() {
Expect(s.Push([]byte("foo"), 0, false)).To(Succeed())
Expect(s.Push([]byte("bar"), 3, false)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(Equal([]byte("foo")))
Expect(fin).To(BeFalse())
data, fin = s.Pop()
Expect(data).To(Equal([]byte("bar")))
Expect(fin).To(BeFalse())
Expect(s.Pop()).To(BeNil())
})
It("ignores empty frames", func() {
Expect(s.Push(nil, 0, false)).To(Succeed())
Expect(s.Pop()).To(BeNil())
})
Context("FIN handling", func() {
It("saves a FIN at offset 0", func() {
Expect(s.Push(nil, 0, true)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(BeEmpty())
Expect(fin).To(BeTrue())
data, fin = s.Pop()
Expect(data).To(BeNil())
Expect(fin).To(BeTrue())
})
It("saves a FIN frame at non-zero offset", func() {
Expect(s.Push([]byte("foobar"), 0, true)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(Equal([]byte("foobar")))
Expect(fin).To(BeTrue())
data, fin = s.Pop()
Expect(data).To(BeNil())
Expect(fin).To(BeTrue())
})
It("sets the FIN if a stream is closed after receiving some data", func() {
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
Expect(s.Push(nil, 6, true)).To(Succeed())
data, fin := s.Pop()
Expect(data).To(Equal([]byte("foobar")))
Expect(fin).To(BeTrue())
data, fin = s.Pop()
Expect(data).To(BeNil())
Expect(fin).To(BeTrue())
})
})
Context("Gap handling", func() {
It("finds the first gap", func() {
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 10},
{Start: 16, End: protocol.MaxByteCount},
})
})
It("correctly sets the first gap for a frame with offset 0", func() {
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 6, End: protocol.MaxByteCount},
})
})
It("finds the two gaps", func() {
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 20, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 10},
{Start: 16, End: 20},
{Start: 26, End: protocol.MaxByteCount},
})
})
It("finds the two gaps in reverse order", func() {
Expect(s.Push([]byte("foobar"), 20, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 10},
{Start: 16, End: 20},
{Start: 26, End: protocol.MaxByteCount},
})
})
It("shrinks a gap when it is partially filled", func() {
Expect(s.Push([]byte("test"), 10, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 4, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 4},
{Start: 14, End: protocol.MaxByteCount},
})
})
It("deletes a gap at the beginning, when it is filled", func() {
Expect(s.Push([]byte("test"), 6, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 10, End: protocol.MaxByteCount},
})
})
It("deletes a gap in the middle, when it is filled", func() {
Expect(s.Push([]byte("test"), 0, false)).To(Succeed())
Expect(s.Push([]byte("test2"), 10, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 4, false)).To(Succeed())
Expect(s.queue).To(HaveLen(3))
checkGaps([]utils.ByteInterval{
{Start: 15, End: protocol.MaxByteCount},
})
})
It("splits a gap into two", func() {
Expect(s.Push([]byte("test"), 100, false)).To(Succeed())
Expect(s.Push([]byte("foobar"), 50, false)).To(Succeed())
Expect(s.queue).To(HaveLen(2))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 50},
{Start: 56, End: 100},
{Start: 104, End: protocol.MaxByteCount},
})
})
Context("Overlapping Stream Data detection", func() {
// create gaps: 0-5, 10-15, 20-25, 30-inf
BeforeEach(func() {
Expect(s.Push([]byte("12345"), 5, false)).To(Succeed())
Expect(s.Push([]byte("12345"), 15, false)).To(Succeed())
Expect(s.Push([]byte("12345"), 25, false)).To(Succeed())
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 10, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame with offset 0 that overlaps at the end", func() {
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(0)))
Expect(s.queue[0]).To(Equal([]byte("fooba")))
Expect(s.queue[0]).To(HaveCap(5))
checkGaps([]utils.ByteInterval{
{Start: 10, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame that overlaps at the end", func() {
// 4 to 7
Expect(s.Push([]byte("foo"), 4, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(4)))
Expect(s.queue[4]).To(Equal([]byte("f")))
Expect(s.queue[4]).To(HaveCap(1))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 4},
{Start: 10, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame that completely fills a gap, but overlaps at the end", func() {
// 10 to 16
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal([]byte("fooba")))
Expect(s.queue[10]).To(HaveCap(5))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame that overlaps at the beginning", func() {
// 8 to 14
Expect(s.Push([]byte("foobar"), 8, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal([]byte("obar")))
Expect(s.queue[10]).To(HaveCap(4))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 14, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that overlaps at the beginning and at the end, starting in a gap", func() {
// 2 to 12
Expect(s.Push([]byte("1234567890"), 2, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
Expect(s.queue[2]).To(Equal([]byte("1234567890")))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 2},
{Start: 12, End: 15},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() {
// 2 to 17
Expect(s.Push([]byte("123456789012345"), 2, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
Expect(s.queue[2]).To(Equal([]byte("1234567890123")))
Expect(s.queue[2]).To(HaveCap(13))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 2},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() {
// 5 to 22
Expect(s.Push([]byte("12345678901234567"), 5, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
Expect(s.queue[10]).To(Equal([]byte("678901234567")))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 22, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that closes multiple gaps", func() {
// 2 to 27
Expect(s.Push(bytes.Repeat([]byte{'e'}, 25), 2, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(25)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
Expect(s.queue[2]).To(Equal(bytes.Repeat([]byte{'e'}, 23)))
Expect(s.queue[2]).To(HaveCap(23))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 2},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that closes multiple gaps", func() {
// 5 to 27
Expect(s.Push(bytes.Repeat([]byte{'d'}, 22), 5, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(5)))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(25)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal(bytes.Repeat([]byte{'d'}, 15)))
Expect(s.queue[10]).To(HaveCap(15))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that covers multiple gaps and ends at the end of a gap", func() {
data := bytes.Repeat([]byte{'e'}, 14)
// 1 to 15
Expect(s.Push(data, 1, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(1)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(15)))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
Expect(s.queue[1]).To(Equal(data))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 1},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("processes a frame that closes all gaps (except for the last one)", func() {
data := bytes.Repeat([]byte{'f'}, 32)
// 0 to 32
Expect(s.Push(data, 0, false)).To(Succeed())
Expect(s.queue).To(HaveLen(1))
Expect(s.queue).To(HaveKey(protocol.ByteCount(0)))
Expect(s.queue[0]).To(Equal(data))
checkGaps([]utils.ByteInterval{
{Start: 32, End: protocol.MaxByteCount},
})
})
It("cuts a frame that overlaps at the beginning and at the end, starting in data already received", func() {
// 8 to 17
Expect(s.Push([]byte("123456789"), 8, false)).To(Succeed())
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8)))
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal([]byte("34567")))
Expect(s.queue[10]).To(HaveCap(5))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
It("cuts a frame that completely covers two gaps", func() {
// 10 to 20
Expect(s.Push([]byte("1234567890"), 10, false)).To(Succeed())
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
Expect(s.queue[10]).To(Equal([]byte("12345")))
Expect(s.queue[10]).To(HaveCap(5))
checkGaps([]utils.ByteInterval{
{Start: 0, End: 5},
{Start: 20, End: 25},
{Start: 30, End: protocol.MaxByteCount},
})
})
})
Context("duplicate data", func() {
expectedGaps := []utils.ByteInterval{
{Start: 5, End: 10},
{Start: 15, End: protocol.MaxByteCount},
}
BeforeEach(func() {
// create gaps: 5-10, 15-inf
Expect(s.Push([]byte("12345"), 0, false)).To(Succeed())
Expect(s.Push([]byte("12345"), 10, false)).To(Succeed())
checkGaps(expectedGaps)
})
AfterEach(func() {
// check that the gaps were not modified
checkGaps(expectedGaps)
})
It("does not modify data when receiving a duplicate", func() {
err := s.push([]byte("fffff"), 0, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[0]).ToNot(Equal([]byte("fffff")))
})
It("detects a duplicate frame that is smaller than the original, starting at the beginning", func() {
// 10 to 12
err := s.push([]byte("12"), 10, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[10]).To(HaveLen(5))
})
It("detects a duplicate frame that is smaller than the original, somewhere in the middle", func() {
// 1 to 4
err := s.push([]byte("123"), 1, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[0]).To(HaveLen(5))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1)))
})
It("detects a duplicate frame that is smaller than the original, somewhere in the middle in the last block", func() {
// 11 to 14
err := s.push([]byte("123"), 11, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[10]).To(HaveLen(5))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11)))
})
It("detects a duplicate frame that is smaller than the original, with aligned end in the last block", func() {
// 11 to 15
err := s.push([]byte("1234"), 1, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[10]).To(HaveLen(5))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11)))
})
It("detects a duplicate frame that is smaller than the original, with aligned end", func() {
// 3 to 5
err := s.push([]byte("12"), 3, false)
Expect(err).To(MatchError(errDuplicateStreamData))
Expect(s.queue[0]).To(HaveLen(5))
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(3)))
})
})
Context("DoS protection", func() {
It("errors when too many gaps are created", func() {
for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ {
Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7), false)).To(Succeed())
}
Expect(s.gaps.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps))
err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, false)
Expect(err).To(MatchError("Too many gaps in received data"))
})
})
})
})
})

View File

@ -0,0 +1,314 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type roundTripperOpts struct {
DisableCompression bool
}
var dialAddr = quic.DialAddr
// client is a HTTP2 client doing QUIC requests
type client struct {
mutex sync.RWMutex
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
hostname string
handshakeErr error
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
session quic.Session
headerStream quic.Stream
headerErr *qerr.QuicError
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
logger utils.Logger
}
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDOmission: true,
KeepAlive: true,
}
// newClient creates a new client
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
logger: utils.DefaultLogger.WithPrefix("client"),
}
}
// dial dials the connection
func (c *client) dial() error {
var err error
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
// once the version has been negotiated, open the header stream
c.headerStream, err = c.session.OpenStream()
if err != nil {
return err
}
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
go c.handleHeaderStream()
return nil
}
func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var err error
for err == nil {
err = c.readResponse(h2framer, decoder)
}
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
c.logger.Debugf("Error handling header stream: %s", err)
}
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request
close(c.headerErrored)
}
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
frame, err := h2framer.ReadFrame()
if err != nil {
return err
}
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
return errors.New("not a headers frame")
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
return fmt.Errorf("cannot read header fields: %s", err.Error())
}
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
return err
}
responseChan <- rsp
return nil
}
// Roundtrip executes a request and returns a response
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme")
}
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
}
c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
}
hasBody := (req.Body != nil)
responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync()
if err != nil {
_ = c.closeWithError(err)
return nil, err
}
c.mutex.Lock()
c.responses[dataStream.StreamID()] = responseChan
c.mutex.Unlock()
var requestedGzip bool
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true
}
// TODO: add support for trailers
endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil {
_ = c.closeWithError(err)
return nil, err
}
resc := make(chan error, 1)
if hasBody {
go func() {
resc <- c.writeRequestBody(dataStream, req.Body)
}()
}
var res *http.Response
var receivedResponse bool
var bodySent bool
if !hasBody {
bodySent = true
}
ctx := req.Context()
for !(bodySent && receivedResponse) {
select {
case res = <-responseChan:
receivedResponse = true
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
case err := <-resc:
bodySent = true
if err != nil {
return nil, err
}
case <-ctx.Done():
// error code 6 signals that stream was canceled
dataStream.CancelRead(6)
dataStream.CancelWrite(6)
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
return nil, ctx.Err()
case <-c.headerErrored:
// an error occurred on the header stream
_ = c.closeWithError(c.headerErr)
return nil, c.headerErr
}
}
// TODO: correctly set this variable
var streamEnded bool
isHead := (req.Method == "HEAD")
res = setLength(res, isHead, streamEnded)
if streamEnded || isHead {
res.Body = noBody
} else {
res.Body = dataStream
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &gzipReader{body: res.Body}
res.Uncompressed = true
}
}
res.Request = req
return res, nil
}
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
defer func() {
cerr := body.Close()
if err == nil {
// TODO: what to do with dataStream here? Maybe reset it?
err = cerr
}
}()
_, err = io.Copy(dataStream, body)
if err != nil {
// TODO: what to do with dataStream here? Maybe reset it?
return err
}
return dataStream.Close()
}
func (c *client) closeWithError(e error) error {
if c.session == nil {
return nil
}
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
}
// Close closes the client
func (c *client) Close() error {
if c.session == nil {
return nil
}
return c.session.Close()
}
// copied from net/transport.go
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}

View File

@ -0,0 +1,639 @@
package h2quic
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"io"
"net/http"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Client", func() {
var (
client *client
session *mockSession
headerStream *mockStream
req *http.Request
origDialAddr = dialAddr
)
injectResponse := func(id protocol.StreamID, rsp *http.Response) {
EventuallyWithOffset(0, func() bool {
client.mutex.Lock()
defer client.mutex.Unlock()
_, ok := client.responses[id]
return ok
}).Should(BeTrue())
rspChan := client.responses[5]
ExpectWithOffset(0, rspChan).ToNot(BeClosed())
rspChan <- rsp
}
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal(hostname))
session = newMockSession()
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
client.session = session
headerStream = newMockStream(3)
client.headerStream = headerStream
client.requestWriter = newRequestWriter(headerStream, utils.DefaultLogger)
var err error
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("saves the TLS config", func() {
tlsConf := &tls.Config{InsecureSkipVerify: true}
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
Expect(client.tlsConf).To(Equal(tlsConf))
})
It("saves the QUIC config", func() {
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
Expect(client.config).To(Equal(quicConf))
})
It("uses the default QUIC config if none is give", func() {
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
Expect(client.config).ToNot(BeNil())
Expect(client.config).To(Equal(defaultQuicConfig))
})
It("adds the port to the hostname, if none is given", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
})
It("dials", func() {
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
close(headerStream.unblockRead)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
// fmt.Println("done")
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("errors when dialing fails", func() {
testErr := errors.New("handshake error")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("uses the custom dialer, if provided", func() {
var tlsCfg *tls.Config
var qCfg *quic.Config
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
tlsCfg = tlsCfgP
qCfg = cfg
return session, nil
}
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
Expect(qCfg).To(Equal(client.config))
Expect(tlsCfg).To(Equal(client.tlsConf))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("errors if it can't open a stream", func() {
testErr := errors.New("you shall not pass")
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
session.streamOpenErr = testErr
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
_, err := client.RoundTrip(req)
Expect(err).To(MatchError(testErr))
})
It("returns a request when dial fails", func() {
testErr := errors.New("dial error")
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return nil, testErr
}
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
_, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
close(done)
}()
_, err = client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Eventually(done).Should(BeClosed())
})
Context("Doing requests", func() {
var request *http.Request
var dataStream *mockStream
getRequest := func(data []byte) *http2.MetaHeadersFrame {
r := bytes.NewReader(data)
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, r)
frame, err := h2framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
return mhframe
}
getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
fields := make(map[string]string)
for _, hf := range f.Fields {
fields[hf.Name] = hf.Value
}
return fields
}
BeforeEach(func() {
var err error
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
return session, nil
}
dataStream = newMockStream(5)
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})
It("does a request", func() {
teapot := &http.Response{
Status: "418 I'm a teapot",
StatusCode: 418,
}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).To(Equal(teapot))
Expect(rsp.Body).To(Equal(dataStream))
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
Expect(rsp.Request).To(Equal(request))
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
injectResponse(5, teapot)
Expect(client.headerErrored).ToNot(BeClosed())
Eventually(done).Should(BeClosed())
})
It("errors if a request without a body is canceled", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
cancel()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("errors if a request with a body is canceled after the body is sent", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
request.Body = &mockBody{}
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
time.Sleep(10 * time.Millisecond)
cancel()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("errors if a request with a body is canceled before the body is sent", func() {
done := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
defer GinkgoRecover()
request = request.WithContext(ctx)
request.Body = &mockBody{}
cancel()
time.Sleep(10 * time.Millisecond)
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(context.Canceled))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(dataStream.reset).To(BeTrue())
Expect(dataStream.canceledWrite).To(BeTrue())
Expect(client.headerErrored).ToNot(BeClosed())
})
It("closes the quic client when encountering an error on the header stream", func() {
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(client.headerErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
})
It("returns subsequent request if there was an error on the header stream before", func() {
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
_, err := client.RoundTrip(request)
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
// now that the first request failed due to an error on the header stream, try another request
_, nextErr := client.RoundTrip(request)
Expect(nextErr).To(MatchError(err))
})
It("blocks if no stream is available", func() {
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
session.blockOpenStreamSync = true
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
client.Close()
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
Context("validating the address", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
It("refuses to do plain HTTP requests", func() {
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
Expect(err).To(MatchError("quic http2: unsupported scheme"))
})
It("adds the port for request URLs without one", func() {
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
// the client.RoundTrip will block, because the encryption level is still set to Unencrypted
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Consistently(done).ShouldNot(BeClosed())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
})
It("sets the EndStream header for requests without a body", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RoundTrip(request)
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
mhf := getRequest(headerStream.dataWritten.Bytes())
Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("sets the EndStream header to false for requests with a body", func() {
request.Body = &mockBody{}
done := make(chan struct{})
go func() {
defer GinkgoRecover()
client.RoundTrip(request)
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
mhf := getRequest(headerStream.dataWritten.Bytes())
Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
Context("requests containing a Body", func() {
var requestBody []byte
var response *http.Response
BeforeEach(func() {
requestBody = []byte("request body")
body := &mockBody{}
body.SetData(requestBody)
request.Body = body
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
// fake a handshake
client.dialOnce.Do(func() {})
session.streamsToOpen = []quic.Stream{dataStream}
})
It("sends a request", func() {
rspChan := make(chan *http.Response)
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
rspChan <- rsp
}()
injectResponse(5, response)
Eventually(rspChan).Should(Receive(Equal(response)))
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
Expect(dataStream.closed).To(BeTrue())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
It("returns the error that occurred when reading the body", func() {
testErr := errors.New("testErr")
request.Body.(*mockBody).readErr = testErr
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
It("returns the error that occurred when closing the body", func() {
testErr := errors.New("testErr")
request.Body.(*mockBody).closeErr = testErr
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).To(MatchError(testErr))
Expect(rsp).To(BeNil())
close(done)
}()
Eventually(done).Should(BeClosed())
Expect(request.Body.(*mockBody).closed).To(BeTrue())
})
})
Context("gzip compression", func() {
var gzippedData []byte // a gzipped foobar
var response *http.Response
BeforeEach(func() {
var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write([]byte("foobar"))
w.Close()
gzippedData = b.Bytes()
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
})
It("adds the gzip header to requests", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).ToNot(BeNil())
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
Expect(rsp.Header.Get("Content-Length")).To(BeEmpty())
data := make([]byte, 6)
_, err = io.ReadFull(rsp.Body, data)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal([]byte("foobar")))
close(done)
}()
dataStream.dataToRead.Write(gzippedData)
response.Header.Add("Content-Encoding", "gzip")
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
close(dataStream.unblockRead)
Eventually(done).Should(BeClosed())
})
It("doesn't add gzip if the header disable it", func() {
client.opts.DisableCompression = true
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
close(done)
}()
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).ToNot(HaveKey("accept-encoding"))
// make the go routine return
injectResponse(5, &http.Response{})
Eventually(done).Should(BeClosed())
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
Expect(rsp).ToNot(BeNil())
data := make([]byte, 11)
rsp.Body.Read(data)
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(data).To(Equal([]byte("not gzipped")))
close(done)
}()
dataStream.dataToRead.Write([]byte("not gzipped"))
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
Eventually(done).Should(BeClosed())
})
It("doesn't add the gzip header for requests that have the accept-enconding set", func() {
request.Header.Add("accept-encoding", "gzip")
done := make(chan struct{})
go func() {
defer GinkgoRecover()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
data := make([]byte, 12)
_, err = rsp.Body.Read(data)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(data).To(Equal([]byte("gzipped data")))
close(done)
}()
dataStream.dataToRead.Write([]byte("gzipped data"))
injectResponse(5, response)
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
Eventually(done).Should(BeClosed())
})
})
Context("handling the header stream", func() {
var h2framer *http2.Framer
BeforeEach(func() {
h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
client.responses[23] = make(chan *http.Response)
})
It("reads header values from a response", func() {
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
headerStream.dataToRead.Write(data)
go client.handleHeaderStream()
var rsp *http.Response
Eventually(client.responses[23]).Should(Receive(&rsp))
Expect(rsp).ToNot(BeNil())
Expect(rsp.Proto).To(Equal("HTTP/2.0"))
Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
Expect(rsp.StatusCode).To(BeEquivalentTo(302))
Expect(rsp.Status).To(Equal("302 Found"))
Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
})
It("errors if the H2 frame is not a HeadersFrame", func() {
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
})
It("errors if it can't read the HPACK encoded header fields", func() {
h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 23,
EndHeaders: true,
BlockFragment: []byte("invalid HPACK data"),
})
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
})
It("errors if the stream cannot be found", func() {
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1337,
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
Expect(err).ToNot(HaveOccurred())
client.handleHeaderStream()
Eventually(client.headerErrored).Should(BeClosed())
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
})
})
})
})

View File

@ -0,0 +1,35 @@
package h2quic
// copied from net/transport.go
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
)
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}

View File

@ -0,0 +1,13 @@
package h2quic
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestH2quic(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "H2quic Suite")
}

View File

@ -0,0 +1,77 @@
package h2quic
import (
"crypto/tls"
"errors"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http2/hpack"
)
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
var path, authority, method, contentLengthStr string
httpHeaders := http.Header{}
for _, h := range headers {
switch h.Name {
case ":path":
path = h.Value
case ":method":
method = h.Value
case ":authority":
authority = h.Value
case "content-length":
contentLengthStr = h.Value
default:
if !h.IsPseudo() {
httpHeaders.Add(h.Name, h.Value)
}
}
}
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(httpHeaders["Cookie"]) > 0 {
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
}
if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty")
}
u, err := url.Parse(path)
if err != nil {
return nil, err
}
var contentLength int64
if len(contentLengthStr) > 0 {
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, err
}
}
return &http.Request{
Method: method,
URL: u,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders,
Body: nil,
ContentLength: contentLength,
Host: authority,
RequestURI: path,
TLS: &tls.ConnectionState{},
}, nil
}
func hostnameFromRequest(req *http.Request) string {
if req.URL != nil {
return req.URL.Host
}
return ""
}

View File

@ -0,0 +1,29 @@
package h2quic
import (
"io"
quic "github.com/lucas-clemente/quic-go"
)
type requestBody struct {
requestRead bool
dataStream quic.Stream
}
// make sure the requestBody can be used as a http.Request.Body
var _ io.ReadCloser = &requestBody{}
func newRequestBody(stream quic.Stream) *requestBody {
return &requestBody{dataStream: stream}
}
func (b *requestBody) Read(p []byte) (int, error) {
b.requestRead = true
return b.dataStream.Read(p)
}
func (b *requestBody) Close() error {
// stream's Close() closes the write side, not the read side
return nil
}

View File

@ -0,0 +1,39 @@
package h2quic
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request body", func() {
var (
stream *mockStream
rb *requestBody
)
BeforeEach(func() {
stream = &mockStream{}
stream.dataToRead.Write([]byte("foobar")) // provides data to be read
rb = newRequestBody(stream)
})
It("reads from the stream", func() {
b := make([]byte, 10)
n, _ := stream.Read(b)
Expect(n).To(Equal(6))
Expect(b[0:6]).To(Equal([]byte("foobar")))
})
It("saves if the stream was read from", func() {
Expect(rb.requestRead).To(BeFalse())
rb.Read(make([]byte, 1))
Expect(rb.requestRead).To(BeTrue())
})
It("doesn't close the stream when closing the request body", func() {
Expect(stream.closed).To(BeFalse())
err := rb.Close()
Expect(err).ToNot(HaveOccurred())
Expect(stream.closed).To(BeFalse())
})
})

View File

@ -0,0 +1,121 @@
package h2quic
import (
"net/http"
"net/url"
"golang.org/x/net/http2/hpack"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request", func() {
It("populates request", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "content-length", Value: "42"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Method).To(Equal("GET"))
Expect(req.URL.Path).To(Equal("/foo"))
Expect(req.Proto).To(Equal("HTTP/2.0"))
Expect(req.ProtoMajor).To(Equal(2))
Expect(req.ProtoMinor).To(Equal(0))
Expect(req.ContentLength).To(Equal(int64(42)))
Expect(req.Header).To(BeEmpty())
Expect(req.Body).To(BeNil())
Expect(req.Host).To(Equal("quic.clemente.io"))
Expect(req.RequestURI).To(Equal("/foo"))
Expect(req.TLS).ToNot(BeNil())
})
It("concatenates the cookie headers", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "cookie", Value: "cookie1=foobar1"},
{Name: "cookie", Value: "cookie2=foobar2"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Header).To(Equal(http.Header{
"Cookie": []string{"cookie1=foobar1; cookie2=foobar2"},
}))
})
It("handles other headers", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
{Name: "cache-control", Value: "max-age=0"},
{Name: "duplicate-header", Value: "1"},
{Name: "duplicate-header", Value: "2"},
}
req, err := requestFromHeaders(headers)
Expect(err).NotTo(HaveOccurred())
Expect(req.Header).To(Equal(http.Header{
"Cache-Control": []string{"max-age=0"},
"Duplicate-Header": []string{"1", "2"},
}))
})
It("errors with missing path", func() {
headers := []hpack.HeaderField{
{Name: ":authority", Value: "quic.clemente.io"},
{Name: ":method", Value: "GET"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
It("errors with missing method", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":authority", Value: "quic.clemente.io"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
It("errors with missing authority", func() {
headers := []hpack.HeaderField{
{Name: ":path", Value: "/foo"},
{Name: ":method", Value: "GET"},
}
_, err := requestFromHeaders(headers)
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
})
Context("extracting the hostname from a request", func() {
var url *url.URL
BeforeEach(func() {
var err error
url, err = url.Parse("https://quic.clemente.io:1337")
Expect(err).ToNot(HaveOccurred())
})
It("uses req.URL.Host", func() {
req := &http.Request{URL: url}
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
})
It("uses req.URL.Host even if req.Host is available", func() {
req := &http.Request{
Host: "www.example.org",
URL: url,
}
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
})
It("returns an empty hostname if nothing is set", func() {
Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty())
})
})
})

View File

@ -0,0 +1,203 @@
package h2quic
import (
"bytes"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type requestWriter struct {
mutex sync.Mutex
headerStream quic.Stream
henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this
logger utils.Logger
}
const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
rw := &requestWriter{
headerStream: headerStream,
logger: logger,
}
rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw
}
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
// TODO: add support for trailers
// TODO: add support for gzip compression
// TODO: write continuation frames, if the header frame is too long
w.mutex.Lock()
defer w.mutex.Unlock()
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
h2framer := http2.NewFramer(w.headerStream, nil)
return h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(dataStreamID),
EndHeaders: true,
EndStream: endStream,
BlockFragment: w.hbuf.Bytes(),
Priority: http2.PriorityParam{Weight: 0xff},
})
}
// the rest of this files is copied from http2.Transport
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
w.hbuf.Reset()
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return nil, err
}
var path string
if req.Method != "CONNECT" {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
}
return nil, fmt.Errorf("invalid request :path %q", orig)
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
}
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
w.writeHeader(":authority", host)
w.writeHeader(":method", req.Method)
if req.Method != "CONNECT" {
w.writeHeader(":path", path)
w.writeHeader(":scheme", req.URL.Scheme)
}
if trailers != "" {
w.writeHeader("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
switch lowKey {
case "host", "content-length":
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
case "user-agent":
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
w.writeHeader(lowKey, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
w.writeHeader("accept-encoding", "gzip")
}
if !didUA {
w.writeHeader("user-agent", defaultUserAgent)
}
return w.hbuf.Bytes(), nil
}
func (w *requestWriter) writeHeader(name, value string) {
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
func actualContentLength(req *http.Request) int64 {
if req.Body == nil {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}

View File

@ -0,0 +1,118 @@
package h2quic
import (
"bytes"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Request", func() {
var (
rw *requestWriter
headerStream *mockStream
decoder *hpack.Decoder
)
BeforeEach(func() {
headerStream = &mockStream{}
rw = newRequestWriter(headerStream, utils.DefaultLogger)
decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
})
decode := func(p []byte) (*http2.HeadersFrame, map[string] /* HeaderField.Name */ string /* HeaderField.Value */) {
framer := http2.NewFramer(nil, bytes.NewReader(p))
frame, err := framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
headerFrame := frame.(*http2.HeadersFrame)
fields, err := decoder.DecodeFull(headerFrame.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
values := make(map[string]string)
for _, headerField := range fields {
values[headerField.Name] = headerField.Value
}
return headerFrame, values
}
It("writes a GET request", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, false)
headerFrame, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamID).To(Equal(uint32(1337)))
Expect(headerFrame.HasPriority()).To(BeTrue())
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
})
It("sets the EndStream header", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, false)
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamEnded()).To(BeTrue())
})
It("doesn't set the EndStream header, if requested", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, false, false)
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
Expect(headerFrame.StreamEnded()).To(BeFalse())
})
It("requests gzip compression, if requested", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 1337, true, true)
_, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
It("writes a POST request", func() {
form := url.Values{}
form.Add("foo", "bar")
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode()))
Expect(err).ToNot(HaveOccurred())
rw.WriteRequest(req, 5, true, false)
_, headerFields := decode(headerStream.dataWritten.Bytes())
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
Expect(headerFields).To(HaveKey("content-length"))
contentLength, err := strconv.Atoi(headerFields["content-length"])
Expect(err).ToNot(HaveOccurred())
Expect(contentLength).To(BeNumerically(">", 0))
})
It("sends cookies", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
cookie1 := &http.Cookie{
Name: "Cookie #1",
Value: "Value #1",
}
cookie2 := &http.Cookie{
Name: "Cookie #2",
Value: "Value #2",
}
req.AddCookie(cookie1)
req.AddCookie(cookie2)
rw.WriteRequest(req, 11, true, false)
_, headerFields := decode(headerStream.dataWritten.Bytes())
// TODO(lclemente): Remove Or() once we drop support for Go 1.8.
Expect(headerFields).To(Or(
HaveKeyWithValue("cookie", "Cookie #1=Value #1; Cookie #2=Value #2"),
HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`),
))
})
})

View File

@ -0,0 +1,95 @@
package h2quic
import (
"bytes"
"errors"
"io/ioutil"
"net/http"
"net/textproto"
"strconv"
"strings"
"golang.org/x/net/http2"
)
// copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
if f.Truncated {
return nil, errResponseHeaderListSize
}
status := f.PseudoValue("status")
if status == "" {
return nil, errors.New("missing status pseudo header")
}
statusCode, err := strconv.Atoi(status)
if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header")
}
// TODO: handle statusCode == 100
header := make(http.Header)
res := &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: header,
StatusCode: statusCode,
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
t = make(http.Header)
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil
})
} else {
header[key] = append(header[key], hf.Value)
}
}
return res, nil
}
// continuation of the handleResponse function
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if !streamEnded || isHead {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64
}
}
}
return res
}
// copied from net/http/server.go
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}

View File

@ -0,0 +1,114 @@
package h2quic
import (
"bytes"
"net/http"
"strconv"
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
type responseWriter struct {
dataStreamID protocol.StreamID
dataStream quic.Stream
headerStream quic.Stream
headerStreamMutex *sync.Mutex
header http.Header
status int // status code passed to WriteHeader
headerWritten bool
logger utils.Logger
}
func newResponseWriter(
headerStream quic.Stream,
headerStreamMutex *sync.Mutex,
dataStream quic.Stream,
dataStreamID protocol.StreamID,
logger utils.Logger,
) *responseWriter {
return &responseWriter{
header: http.Header{},
headerStream: headerStream,
headerStreamMutex: headerStreamMutex,
dataStream: dataStream,
dataStreamID: dataStreamID,
logger: logger,
}
}
func (w *responseWriter) Header() http.Header {
return w.header
}
func (w *responseWriter) WriteHeader(status int) {
if w.headerWritten {
return
}
w.headerWritten = true
w.status = status
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
for k, v := range w.header {
for index := range v {
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
}
}
w.logger.Infof("Responding with %d", status)
w.headerStreamMutex.Lock()
defer w.headerStreamMutex.Unlock()
h2framer := http2.NewFramer(w.headerStream, nil)
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(w.dataStreamID),
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
if err != nil {
w.logger.Errorf("could not write h2 header: %s", err.Error())
}
}
func (w *responseWriter) Write(p []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(200)
}
if !bodyAllowedForStatus(w.status) {
return 0, http.ErrBodyNotAllowed
}
return w.dataStream.Write(p)
}
func (w *responseWriter) Flush() {}
// This is a NOP. Use http.Request.Context
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
// test that we implement http.Flusher
var _ http.Flusher = &responseWriter{}
// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}

View File

@ -0,0 +1,9 @@
package h2quic
import "net/http"
// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11.
// By defining it in a separate file, we can exclude this file from staticcheck.
// test that we implement http.CloseNotifier
var _ http.CloseNotifier = &responseWriter{}

View File

@ -0,0 +1,163 @@
package h2quic
import (
"bytes"
"context"
"io"
"net/http"
"sync"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockStream struct {
id protocol.StreamID
dataToRead bytes.Buffer
dataWritten bytes.Buffer
reset bool
canceledWrite bool
closed bool
remoteClosed bool
unblockRead chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
}
var _ quic.Stream = &mockStream{}
func newMockStream(id protocol.StreamID) *mockStream {
s := &mockStream{
id: id,
unblockRead: make(chan struct{}),
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
return s
}
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil }
func (s *mockStream) CancelWrite(quic.ErrorCode) error { s.canceledWrite = true; return nil }
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
func (s mockStream) StreamID() protocol.StreamID { return s.id }
func (s *mockStream) Context() context.Context { return s.ctx }
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }
func (s *mockStream) Read(p []byte) (int, error) {
n, _ := s.dataToRead.Read(p)
if n == 0 { // block if there's no data
<-s.unblockRead
return 0, io.EOF
}
return n, nil // never return an EOF
}
func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) }
var _ = Describe("Response Writer", func() {
var (
w *responseWriter
headerStream *mockStream
dataStream *mockStream
)
BeforeEach(func() {
headerStream = &mockStream{}
dataStream = &mockStream{}
w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5, utils.DefaultLogger)
})
decodeHeaderFields := func() map[string][]string {
fields := make(map[string][]string)
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, bytes.NewReader(headerStream.dataWritten.Bytes()))
frame, err := h2framer.ReadFrame()
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(BeAssignableToTypeOf(&http2.HeadersFrame{}))
hframe := frame.(*http2.HeadersFrame)
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
Expect(mhframe.StreamID).To(BeEquivalentTo(5))
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
Expect(err).ToNot(HaveOccurred())
for _, p := range mhframe.Fields {
fields[p.Name] = append(fields[p.Name], p.Value)
}
return fields
}
It("writes status", func() {
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveLen(1))
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
})
It("writes headers", func() {
w.Header().Add("content-length", "42")
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
})
It("writes multiple headers with the same name", func() {
const cookie1 = "test1=1; Max-Age=7200; path=/"
const cookie2 = "test2=2; Max-Age=7200; path=/"
w.Header().Add("set-cookie", cookie1)
w.Header().Add("set-cookie", cookie2)
w.WriteHeader(http.StatusTeapot)
fields := decodeHeaderFields()
Expect(fields).To(HaveKey("set-cookie"))
cookies := fields["set-cookie"]
Expect(cookies).To(ContainElement(cookie1))
Expect(cookies).To(ContainElement(cookie2))
})
It("writes data", func() {
n, err := w.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
// Should have written 200 on the header stream
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
// And foobar on the data stream
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
})
It("writes data after WriteHeader is called", func() {
w.WriteHeader(http.StatusTeapot)
n, err := w.Write([]byte("foobar"))
Expect(n).To(Equal(6))
Expect(err).ToNot(HaveOccurred())
// Should have written 418 on the header stream
fields := decodeHeaderFields()
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
// And foobar on the data stream
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
})
It("does not WriteHeader() twice", func() {
w.WriteHeader(200)
w.WriteHeader(500)
fields := decodeHeaderFields()
Expect(fields).To(HaveLen(1))
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
})
It("doesn't allow writes if the status code doesn't allow a body", func() {
w.WriteHeader(304)
n, err := w.Write([]byte("foobar"))
Expect(n).To(BeZero())
Expect(err).To(MatchError(http.ErrBodyNotAllowed))
Expect(dataStream.dataWritten.Bytes()).To(HaveLen(0))
})
})

View File

@ -0,0 +1,179 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/http/httpguts"
)
type roundTripCloser interface {
http.RoundTripper
io.Closer
}
// RoundTripper implements the http.RoundTripper interface
type RoundTripper struct {
mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
clients map[string]roundTripCloser
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may
// create a new QUIC connection. If set true and
// no cached connection is available, RoundTrip
// will return ErrNoCachedConn.
OnlyCachedConn bool
}
var _ roundTripCloser = &RoundTripper{}
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL")
}
if req.URL.Host == "" {
closeRequestBody(req)
return nil, errors.New("quic: no Host in request URL")
}
if req.Header == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.Header")
}
if req.URL.Scheme == "https" {
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
}
}
}
} else {
closeRequestBody(req)
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
}
if req.Method != "" && !validMethod(req.Method) {
closeRequestBody(req)
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
}
hostname := authorityAddr("https", hostnameFromRequest(req))
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
return cl.RoundTrip(req)
}
// RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]roundTripCloser)
}
client, ok := r.clients[hostname]
if !ok {
if onlyCached {
return nil, ErrNoCachedConn
}
client = newClient(
hostname,
r.TLSClientConfig,
&roundTripperOpts{DisableCompression: r.DisableCompression},
r.QuicConfig,
r.Dial,
)
r.clients[hostname] = client
}
return client, nil
}
// Close closes the QUIC connections that this RoundTripper has used
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
return nil
}
func closeRequestBody(req *http.Request) {
if req.Body != nil {
req.Body.Close()
}
}
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// copied from net/http/http.go
func isNotToken(r rune) bool {
return !httpguts.IsTokenRune(r)
}

View File

@ -0,0 +1,218 @@
package h2quic
import (
"bytes"
"crypto/tls"
"errors"
"io"
"net/http"
"time"
quic "github.com/lucas-clemente/quic-go"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockClient struct {
closed bool
}
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{Request: req}, nil
}
func (m *mockClient) Close() error {
m.closed = true
return nil
}
var _ roundTripCloser = &mockClient{}
type mockBody struct {
reader bytes.Reader
readErr error
closeErr error
closed bool
}
func (m *mockBody) Read(p []byte) (int, error) {
if m.readErr != nil {
return 0, m.readErr
}
return m.reader.Read(p)
}
func (m *mockBody) SetData(data []byte) {
m.reader = *bytes.NewReader(data)
}
func (m *mockBody) Close() error {
m.closed = true
return m.closeErr
}
// make sure the mockBody can be used as a http.Request.Body
var _ io.ReadCloser = &mockBody{}
var _ = Describe("RoundTripper", func() {
var (
rt *RoundTripper
req1 *http.Request
)
BeforeEach(func() {
rt = &RoundTripper{}
var err error
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
})
Context("dialing hosts", func() {
origDialAddr := dialAddr
streamOpenErr := errors.New("error opening stream")
BeforeEach(func() {
origDialAddr = dialAddr
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
// return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all
return &mockSession{streamOpenErr: streamOpenErr}, nil
}
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("creates new clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
})
It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeTimeout: time.Millisecond}
var receivedConfig *quic.Config
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
receivedConfig = config
return nil, errors.New("err")
}
rt.QuicConfig = config
rt.RoundTrip(req1)
Expect(receivedConfig).To(Equal(config))
})
It("uses the custom dialer, if provided", func() {
var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
dialed = true
return nil, errors.New("err")
}
rt.Dial = dialer
rt.RoundTrip(req1)
Expect(dialed).To(BeTrue())
})
It("reuses existing clients", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req2)
Expect(err).To(MatchError(streamOpenErr))
Expect(rt.clients).To(HaveLen(1))
})
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
Expect(err).To(MatchError(ErrNoCachedConn))
})
})
Context("validating request", func() {
It("rejects plain HTTP requests", func() {
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
req.Body = &mockBody{}
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError("quic: unsupported protocol scheme: http"))
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects requests without a URL", func() {
req1.URL = nil
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects request without a URL Host", func() {
req1.URL.Host = ""
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: no Host in request URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
})
It("doesn't try to close the body if the request doesn't have one", func() {
req1.URL = nil
Expect(req1.Body).To(BeNil())
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.URL"))
})
It("rejects requests without a header", func() {
req1.Header = nil
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: nil Request.Header"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects requests with invalid header name fields", func() {
req1.Header.Add("foobär", "value")
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: invalid http header field name \"foobär\""))
})
It("rejects requests with invalid header name values", func() {
req1.Header.Add("foo", string([]byte{0x7}))
_, err := rt.RoundTrip(req1)
Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value"))
})
It("rejects requests with an invalid request method", func() {
req1.Method = "foobär"
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError("quic: invalid method \"foobär\""))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
})
})
Context("closing", func() {
It("closes", func() {
rt.clients = make(map[string]roundTripCloser)
cl := &mockClient{}
rt.clients["foo.bar"] = cl
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero())
Expect(cl.closed).To(BeTrue())
})
It("closes a RoundTripper that has never been used", func() {
Expect(len(rt.clients)).To(BeZero())
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero())
})
})
})

View File

@ -0,0 +1,402 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
type streamCreator interface {
quic.Session
GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
}
type remoteCloser interface {
CloseRemote(protocol.ByteCount)
}
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = quic.Listen
quicListenAddr = quic.ListenAddr
)
// Server is a HTTP2 server listening for QUIC connections.
type Server struct {
*http.Server
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
// If nil, it uses reasonable default values.
QuicConfig *quic.Config
// Private flag for demo, do not use
CloseAfterFirstRequest bool
port uint32 // used atomically
listenerMutex sync.Mutex
listener quic.Listener
closed bool
supportedVersionsAsString string
logger utils.Logger // will be set by Server.serveImpl()
}
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
func (s *Server) ListenAndServe() error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
return s.serveImpl(s.TLSConfig, nil)
}
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
}
return s.serveImpl(config, nil)
}
// Serve an existing UDP connection.
func (s *Server) Serve(conn net.PacketConn) error {
return s.serveImpl(s.TLSConfig, conn)
}
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
s.logger = utils.DefaultLogger.WithPrefix("server")
s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
return errors.New("Server is already closed")
}
if s.listener != nil {
s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
}
var ln quic.Listener
var err error
if conn == nil {
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
} else {
ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
}
if err != nil {
s.listenerMutex.Unlock()
return err
}
s.listener = ln
s.listenerMutex.Unlock()
for {
sess, err := ln.Accept()
if err != nil {
return err
}
go s.handleHeaderStream(sess.(streamCreator))
}
}
func (s *Server) handleHeaderStream(session streamCreator) {
stream, err := session.AcceptStream()
if err != nil {
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
return
}
hpackDecoder := hpack.NewDecoder(4096, nil)
h2framer := http2.NewFramer(nil, stream)
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for {
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
// QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't
// need to log it again.
errorCode := qerr.InternalError
if qerr, ok := err.(*qerr.QuicError); !ok {
errorCode = qerr.ErrorCode
s.logger.Errorf("error handling h2 request: %s", err.Error())
}
session.CloseWithError(quic.ErrorCode(errorCode), err)
return
}
}
}
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
h2frame, err := h2framer.ReadFrame()
if err != nil {
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
}
var h2headersFrame *http2.HeadersFrame
switch f := h2frame.(type) {
case *http2.PriorityFrame:
// ignore PRIORITY frames
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
return nil
case *http2.HeadersFrame:
h2headersFrame = f
default:
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
}
if !h2headersFrame.HeadersEnded() {
return errors.New("http2 header continuation not implemented")
}
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
if err != nil {
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
return err
}
req, err := requestFromHeaders(headers)
if err != nil {
return err
}
if s.logger.Debug() {
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
} else {
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
}
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
if err != nil {
return err
}
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
if dataStream == nil {
return nil
}
// handleRequest should be as non-blocking as possible to minimize
// head-of-line blocking. Potentially blocking code is run in a separate
// goroutine, enabling handleRequest to return before the code is executed.
go func() {
streamEnded := h2headersFrame.StreamEnded()
if streamEnded {
dataStream.(remoteCloser).CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof
}
req = req.WithContext(dataStream.Context())
reqBody := newRequestBody(dataStream)
req.Body = reqBody
req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
}
panicked := false
func() {
defer func() {
if p := recover(); p != nil {
// Copied from net/http/server.go
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
panicked = true
}
}()
handler.ServeHTTP(responseWriter, req)
}()
if panicked {
responseWriter.WriteHeader(500)
} else {
responseWriter.WriteHeader(200)
}
if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
// in gQUIC, the error code doesn't matter, so just use 0 here
responseWriter.dataStream.CancelRead(0)
}
responseWriter.dataStream.Close()
}
if s.CloseAfterFirstRequest {
time.Sleep(100 * time.Millisecond)
session.Close()
}
}()
return nil
}
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) Close() error {
s.listenerMutex.Lock()
defer s.listenerMutex.Unlock()
s.closed = true
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
return err
}
return nil
}
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) CloseGracefully(timeout time.Duration) error {
// TODO: implement
return nil
}
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error {
port := atomic.LoadUint32(&s.port)
if port == 0 {
// Extract port from s.Server.Addr
_, portStr, err := net.SplitHostPort(s.Server.Addr)
if err != nil {
return err
}
portInt, err := net.LookupPort("tcp", portStr)
if err != nil {
return err
}
port = uint32(portInt)
atomic.StoreUint32(&s.port, port)
}
if s.supportedVersionsAsString == "" {
var versions []string
for _, v := range protocol.SupportedVersions {
versions = append(versions, v.ToAltSvc())
}
s.supportedVersionsAsString = strings.Join(versions, ",")
}
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
return nil
}
// ListenAndServeQUIC listens on the UDP network address addr and calls the
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
// used when handler is nil.
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
server := &Server{
Server: &http.Server{
Addr: addr,
Handler: handler,
},
}
return server.ListenAndServeTLS(certFile, keyFile)
}
// ListenAndServe listens on the given network address for both, TLS and QUIC
// connetions in parallel. It returns if one of the two returns an error.
// http.DefaultServeMux is used when handler is nil.
// The correct Alt-Svc headers for QUIC are set.
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
// Load certs
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
}
// Open the listeners
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
defer udpConn.Close()
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return err
}
tcpConn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return err
}
defer tcpConn.Close()
tlsConn := tls.NewListener(tcpConn, config)
defer tlsConn.Close()
// Start the servers
httpServer := &http.Server{
Addr: addr,
TLSConfig: config,
}
quicServer := &Server{
Server: httpServer,
}
if handler == nil {
handler = http.DefaultServeMux
}
httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
quicServer.SetQuicHeaders(w.Header())
handler.ServeHTTP(w, r)
})
hErr := make(chan error)
qErr := make(chan error)
go func() {
hErr <- httpServer.Serve(tlsConn)
}()
go func() {
qErr <- quicServer.Serve(udpConn)
}()
select {
case err := <-hErr:
quicServer.Close()
return err
case err := <-qErr:
// Cannot close the HTTP server or wait for requests to complete properly :/
return err
}
}

View File

@ -0,0 +1,536 @@
package h2quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockSession struct {
closed bool
closedWithError error
dataStream quic.Stream
streamToAccept quic.Stream
streamsToOpen []quic.Stream
blockOpenStreamSync bool
blockOpenStreamChan chan struct{} // close this chan (or call Close) to make OpenStreamSync return
streamOpenErr error
ctx context.Context
ctxCancel context.CancelFunc
}
func newMockSession() *mockSession {
return &mockSession{blockOpenStreamChan: make(chan struct{})}
}
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
return s.dataStream, nil
}
func (s *mockSession) AcceptStream() (quic.Stream, error) { return s.streamToAccept, nil }
func (s *mockSession) OpenStream() (quic.Stream, error) {
if s.streamOpenErr != nil {
return nil, s.streamOpenErr
}
str := s.streamsToOpen[0]
s.streamsToOpen = s.streamsToOpen[1:]
return str, nil
}
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
if s.blockOpenStreamSync {
<-s.blockOpenStreamChan
}
return s.OpenStream()
}
func (s *mockSession) Close() error {
s.ctxCancel()
if !s.closed {
close(s.blockOpenStreamChan)
}
s.closed = true
return nil
}
func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error {
s.closedWithError = e
return s.Close()
}
func (s *mockSession) LocalAddr() net.Addr {
panic("not implemented")
}
func (s *mockSession) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
}
func (s *mockSession) Context() context.Context {
return s.ctx
}
func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") }
func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") }
func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") }
func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") }
var _ = Describe("H2 server", func() {
var (
s *Server
session *mockSession
dataStream *mockStream
origQuicListenAddr = quicListenAddr
)
BeforeEach(func() {
s = &Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
logger: utils.DefaultLogger,
}
dataStream = newMockStream(0)
close(dataStream.unblockRead)
session = newMockSession()
session.dataStream = dataStream
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
origQuicListenAddr = quicListenAddr
})
AfterEach(func() {
quicListenAddr = origQuicListenAddr
})
Context("handling requests", func() {
var (
h2framer *http2.Framer
hpackDecoder *hpack.Decoder
headerStream *mockStream
)
BeforeEach(func() {
headerStream = &mockStream{}
hpackDecoder = hpack.NewDecoder(4096, nil)
h2framer = http2.NewFramer(nil, headerStream)
})
It("handles a sample GET request", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.RemoteAddr).To(Equal("127.0.0.1:42"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeTrue())
Expect(dataStream.reset).To(BeFalse())
})
It("returns 200 with an empty handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() []byte {
return headerStream.dataWritten.Bytes()
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200
})
It("correctly handles a panicking handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("foobar")
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() []byte {
return headerStream.dataWritten.Bytes()
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500
})
It("resets the dataStream when client sends a body in GET request", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeFalse())
})
It("resets the dataStream when the body of POST request is not read", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.Method).To(Equal("POST"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
Expect(handlerCalled).To(BeTrue())
})
It("handles a request for which the client immediately resets the data stream", func() {
session.dataStream = nil
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
})
It("resets the dataStream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = struct {
io.Reader
io.Closer
}{}
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
Expect(handlerCalled).To(BeTrue())
})
It("closes the dataStream if the body of POST request was read", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
Expect(r.Method).To(Equal("POST"))
handlerCalled = true
// read the request body
b := make([]byte, 1000)
n, _ := r.Body.Read(b)
Expect(n).ToNot(BeZero())
})
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
dataStream.dataToRead.Write([]byte("foo=bar"))
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.reset).To(BeFalse())
})
It("ignores PRIORITY frames", func() {
handlerCalled := make(chan struct{})
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(handlerCalled)
})
buf := &bytes.Buffer{}
framer := http2.NewFramer(buf, nil)
err := framer.WritePriority(10, http2.PriorityParam{Weight: 42})
Expect(err).ToNot(HaveOccurred())
Expect(buf.Bytes()).ToNot(BeEmpty())
headerStream.dataToRead.Write(buf.Bytes())
err = s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).ToNot(HaveOccurred())
Consistently(handlerCalled).ShouldNot(BeClosed())
Expect(dataStream.reset).To(BeFalse())
Expect(dataStream.closed).To(BeFalse())
})
It("errors when non-header frames are received", func() {
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x06, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5,
'f', 'o', 'o', 'b', 'a', 'r',
})
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame"))
})
It("Cancels the request context when the datstream is closed", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
err := r.Context().Err()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("context canceled"))
handlerCalled = true
})
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
dataStream.Close()
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
Expect(dataStream.remoteClosed).To(BeTrue())
Expect(dataStream.reset).To(BeFalse())
})
})
It("handles the header stream", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
})
It("closes the connection if it encounters an error on the header stream", func() {
var handlerCalled bool
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
Eventually(func() bool { return session.closed }).Should(BeTrue())
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
})
It("supports closing after first request", func() {
s.CloseAfterFirstRequest = true
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
Expect(session.closed).To(BeFalse())
go s.handleHeaderStream(session)
Eventually(func() bool { return session.closed }).Should(BeTrue())
})
It("uses the default handler as fallback", func() {
var handlerCalled bool
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.Host).To(Equal("www.example.com"))
handlerCalled = true
}))
headerStream := &mockStream{id: 3}
headerStream.dataToRead.Write([]byte{
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
})
session.streamToAccept = headerStream
go s.handleHeaderStream(session)
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
})
Context("setting http headers", func() {
var expected http.Header
getExpectedHeader := func(versions []protocol.VersionNumber) http.Header {
var versionsAsString []string
for _, v := range versions {
versionsAsString = append(versionsAsString, v.ToAltSvc())
}
return http.Header{
"Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))},
}
}
BeforeEach(func() {
Expect(getExpectedHeader([]protocol.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}}))
expected = getExpectedHeader(protocol.SupportedVersions)
})
It("sets proper headers with numeric port", func() {
s.Server.Addr = ":443"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("sets proper headers with full addr", func() {
s.Server.Addr = "127.0.0.1:443"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("sets proper headers with string port", func() {
s.Server.Addr = ":https"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
It("works multiple times", func() {
s.Server.Addr = ":https"
hdr := http.Header{}
err := s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
hdr = http.Header{}
err = s.SetQuicHeaders(hdr)
Expect(err).NotTo(HaveOccurred())
Expect(hdr).To(Equal(expected))
})
})
It("should error when ListenAndServe is called with s.Server nil", func() {
err := (&Server{}).ListenAndServe()
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
err := (&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
})
It("should nop-Close() when s.server is nil", func() {
err := (&Server{}).Close()
Expect(err).NotTo(HaveOccurred())
})
It("errors when ListenAndServer is called after Close", func() {
serv := &Server{Server: &http.Server{}}
Expect(serv.Close()).To(Succeed())
err := serv.ListenAndServe()
Expect(err).To(MatchError("Server is already closed"))
})
Context("ListenAndServe", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"
})
AfterEach(func() {
Expect(s.Close()).To(Succeed())
})
It("may only be called once", func() {
cErr := make(chan error)
for i := 0; i < 2; i++ {
go func() {
defer GinkgoRecover()
err := s.ListenAndServe()
if err != nil {
cErr <- err
}
}()
}
err := <-cErr
Expect(err).To(MatchError("ListenAndServe may only be called once"))
err = s.Close()
Expect(err).NotTo(HaveOccurred())
}, 0.5)
It("uses the quic.Config to start the quic server", func() {
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
var receivedConf *quic.Config
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
receivedConf = config
return nil, errors.New("listen err")
}
s.QuicConfig = conf
go s.ListenAndServe()
Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf))
})
})
Context("ListenAndServeTLS", func() {
BeforeEach(func() {
s.Server.Addr = "localhost:0"
})
AfterEach(func() {
err := s.Close()
Expect(err).NotTo(HaveOccurred())
})
It("may only be called once", func() {
cErr := make(chan error)
for i := 0; i < 2; i++ {
go func() {
defer GinkgoRecover()
err := s.ListenAndServeTLS(testdata.GetCertificatePaths())
if err != nil {
cErr <- err
}
}()
}
err := <-cErr
Expect(err).To(MatchError("ListenAndServe may only be called once"))
err = s.Close()
Expect(err).NotTo(HaveOccurred())
}, 0.5)
})
It("closes gracefully", func() {
err := s.CloseGracefully(0)
Expect(err).NotTo(HaveOccurred())
})
It("errors when listening fails", func() {
testErr := errors.New("listen error")
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
return nil, testErr
}
fullpem, privkey := testdata.GetCertificatePaths()
err := ListenAndServeQUIC("", fullpem, privkey, nil)
Expect(err).To(MatchError(testErr))
})
})

View File

@ -0,0 +1,210 @@
package chrome_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync/atomic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gexec"
"testing"
)
const (
dataLen = 500 * 1024 // 500 KB
dataLongLen = 50 * 1024 * 1024 // 50 MB
)
var (
nFilesUploaded int32 // should be used atomically
doneCalled utils.AtomicBool
)
func TestChrome(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Chrome Suite")
}
func init() {
// Requires the len & num GET parameters, e.g. /uploadtest?len=100&num=1
http.HandleFunc("/uploadtest", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
response := uploadHTML
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
_, err := io.WriteString(w, response)
Expect(err).NotTo(HaveOccurred())
})
// Requires the len & num GET parameters, e.g. /downloadtest?len=100&num=1
http.HandleFunc("/downloadtest", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
response := downloadHTML
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
_, err := io.WriteString(w, response)
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/uploadhandler", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
l, err := strconv.Atoi(r.URL.Query().Get("len"))
Expect(err).NotTo(HaveOccurred())
defer r.Body.Close()
actual, err := ioutil.ReadAll(r.Body)
Expect(err).NotTo(HaveOccurred())
Expect(bytes.Equal(actual, testserver.GeneratePRData(l))).To(BeTrue())
atomic.AddInt32(&nFilesUploaded, 1)
})
http.HandleFunc("/done", func(w http.ResponseWriter, r *http.Request) {
doneCalled.Set(true)
})
}
var _ = AfterEach(func() {
testserver.StopQuicServer()
atomic.StoreInt32(&nFilesUploaded, 0)
doneCalled.Set(false)
})
func getChromePath() string {
if runtime.GOOS == "darwin" {
return "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
}
if path, err := exec.LookPath("google-chrome"); err == nil {
return path
}
if path, err := exec.LookPath("chromium-browser"); err == nil {
return path
}
Fail("No Chrome executable found.")
return ""
}
func chromeTest(version protocol.VersionNumber, url string, blockUntilDone func()) {
userDataDir, err := ioutil.TempDir("", "quic-go-test-chrome-dir")
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(userDataDir)
path := getChromePath()
args := []string{
"--disable-gpu",
"--no-first-run=true",
"--no-default-browser-check=true",
"--user-data-dir=" + userDataDir,
"--enable-quic=true",
"--no-proxy-server=true",
"--no-sandbox",
"--origin-to-force-quic-on=quic.clemente.io:443",
fmt.Sprintf(`--host-resolver-rules=MAP quic.clemente.io:443 127.0.0.1:%s`, testserver.Port()),
fmt.Sprintf("--quic-version=QUIC_VERSION_%s", version.ToAltSvc()),
url,
}
utils.DefaultLogger.Infof("Running chrome: %s '%s'", getChromePath(), strings.Join(args, "' '"))
command := exec.Command(path, args...)
session, err := gexec.Start(command, nil, nil)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
blockUntilDone()
}
func waitForDone() {
Eventually(func() bool { return doneCalled.Get() }, 60).Should(BeTrue())
}
func waitForNUploaded(expected int) func() {
return func() {
Eventually(func() int32 {
return atomic.LoadInt32(&nFilesUploaded)
}, 60).Should(BeEquivalentTo(expected))
}
}
const commonJS = `
var buf = new ArrayBuffer(LENGTH);
var prng = new Uint8Array(buf);
var seed = 1;
for (var i = 0; i < LENGTH; i++) {
// https://en.wikipedia.org/wiki/Lehmer_random_number_generator
seed = seed * 48271 % 2147483647;
prng[i] = seed;
}
`
const uploadHTML = `
<html>
<body>
<script>
console.log("Running DL test...");
` + commonJS + `
for (var i = 0; i < NUM; i++) {
var req = new XMLHttpRequest();
req.open("POST", "/uploadhandler?len=" + LENGTH, true);
req.send(buf);
}
</script>
</body>
</html>
`
const downloadHTML = `
<html>
<body>
<script>
console.log("Running DL test...");
` + commonJS + `
function verify(data) {
if (data.length !== LENGTH) return false;
for (var i = 0; i < LENGTH; i++) {
if (data[i] !== prng[i]) return false;
}
return true;
}
var nOK = 0;
for (var i = 0; i < NUM; i++) {
let req = new XMLHttpRequest();
req.responseType = "arraybuffer";
req.open("POST", "/prdata?len=" + LENGTH, true);
req.onreadystatechange = function () {
if (req.readyState === XMLHttpRequest.DONE && req.status === 200) {
if (verify(new Uint8Array(req.response))) {
nOK++;
if (nOK === NUM) {
console.log("Done :)");
var reqDone = new XMLHttpRequest();
reqDone.open("GET", "/done");
reqDone.send();
}
}
}
};
req.send();
}
</script>
</body>
</html>
`

View File

@ -0,0 +1,76 @@
package chrome_test
import (
"fmt"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
)
var _ = Describe("Chrome tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
// TODO: activate Chrome integration tests with gQUIC 44
if version == protocol.Version44 {
continue
}
Context(fmt.Sprintf("with version %s", version), func() {
JustBeforeEach(func() {
testserver.StartQuicServer([]protocol.VersionNumber{version})
})
It("downloads a small file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLen),
waitForDone,
)
})
It("downloads a large file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLongLen),
waitForDone,
)
})
It("loads a large number of files", func() {
chromeTest(
version,
"https://quic.clemente.io/downloadtest?num=4&len=100",
waitForDone,
)
})
It("uploads a small file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLen),
waitForNUploaded(1),
)
})
It("uploads a large file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLongLen),
waitForNUploaded(1),
)
})
It("uploads many small files", func() {
num := protocol.DefaultMaxIncomingStreams + 20
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=%d&len=%d", num, dataLen),
waitForNUploaded(num),
)
})
})
}
})

View File

@ -0,0 +1,137 @@
package gquic_test
import (
"bytes"
"fmt"
mrand "math/rand"
"os/exec"
"strconv"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
var _ = Describe("Drop tests", func() {
var proxy *quicproxy.QuicProxy
startProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DropPacket: dropCallback,
})
Expect(err).ToNot(HaveOccurred())
}
downloadFile := func(version protocol.VersionNumber) {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
downloadHello := func(version protocol.VersionNumber) {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/hello",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: Hello, World!\n"))
}
deterministicDropper := func(p, interval, dropInARow uint64) bool {
return (p % interval) < dropInARow
}
stochasticDropper := func(freq int) bool {
return mrand.Int63n(int64(freq)) == 0
}
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
})
for _, v := range protocol.SupportedVersions {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
Context("during the crypto handshake", func() {
for _, d := range directions {
direction := d
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 1 && d.Is(direction)
}, version)
downloadHello(version)
})
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 2 && d.Is(direction)
}, version)
downloadHello(version)
})
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return d.Is(direction) && stochasticDropper(5)
}, version)
downloadHello(version)
})
}
})
Context("after the crypto handshake", func() {
for _, d := range directions {
direction := d
It(fmt.Sprintf("downloads a file when every 5th packet is dropped in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && deterministicDropper(p, 5, 1)
}, version)
downloadFile(version)
})
It(fmt.Sprintf("downloads a file when 1/5th of all packet are dropped randomly in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && stochasticDropper(5)
}, version)
downloadFile(version)
})
It(fmt.Sprintf("downloads a file when 10 packets every 100 packet are dropped in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && deterministicDropper(p, 100, 10)
}, version)
downloadFile(version)
})
}
})
})
}
})

View File

@ -0,0 +1,45 @@
package gquic_test
import (
"fmt"
"math/rand"
"path/filepath"
"runtime"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
var (
clientPath string
serverPath string
)
func TestIntegration(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "GQuic Tests Suite")
}
var _ = BeforeSuite(func() {
rand.Seed(GinkgoRandomSeed())
})
var _ = JustBeforeEach(func() {
testserver.StartQuicServer(nil)
})
var _ = AfterEach(testserver.StopQuicServer)
func init() {
_, thisfile, _, ok := runtime.Caller(0)
if !ok {
panic("Failed to get current path")
}
clientPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/client-%s-debug", runtime.GOOS))
serverPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/server-%s-debug", runtime.GOOS))
}

View File

@ -0,0 +1,98 @@
package gquic_test
import (
"bytes"
"fmt"
"os/exec"
"sync"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
_ "github.com/lucas-clemente/quic-clients" // download clients
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("Integration tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("gets a simple file", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/hello",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 5).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: Hello, World!\n"))
})
It("posts and reads a body", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"--body=foo",
"https://quic.clemente.io/echo",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 5).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: foo\n"))
})
It("gets a file", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 10).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
})
It("gets many copies of a file in parallel", func() {
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}()
}
wg.Wait()
})
})
}
})

View File

@ -0,0 +1,95 @@
package gquic_test
import (
"bytes"
"fmt"
"math/rand"
"os/exec"
"strconv"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gexec"
)
// get a random duration between min and max
func getRandomDuration(min, max time.Duration) time.Duration {
return min + time.Duration(rand.Int63n(int64(max-min)))
}
var _ = Describe("Random Duration Generator", func() {
It("gets a random RTT", func() {
var min time.Duration = time.Hour
var max time.Duration
var sum time.Duration
rep := 10000
for i := 0; i < rep; i++ {
val := getRandomDuration(100*time.Millisecond, 500*time.Millisecond)
sum += val
if val < min {
min = val
}
if val > max {
max = val
}
}
avg := sum / time.Duration(rep)
Expect(avg).To(BeNumerically("~", 300*time.Millisecond, 5*time.Millisecond))
Expect(min).To(BeNumerically(">=", 100*time.Millisecond))
Expect(min).To(BeNumerically("<", 105*time.Millisecond))
Expect(max).To(BeNumerically(">", 495*time.Millisecond))
Expect(max).To(BeNumerically("<=", 500*time.Millisecond))
})
})
var _ = Describe("Random RTT", func() {
var proxy *quicproxy.QuicProxy
runRTTTest := func(minRtt, maxRtt time.Duration, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
return getRandomDuration(minRtt, maxRtt)
},
})
Expect(err).ToNot(HaveOccurred())
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
AfterEach(func() {
err := proxy.Close()
Expect(err).ToNot(HaveOccurred())
time.Sleep(time.Millisecond)
})
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("gets a file a random RTT between 10ms and 30ms", func() {
runRTTTest(10*time.Millisecond, 30*time.Millisecond, version)
})
})
}
})

View File

@ -0,0 +1,66 @@
package gquic_test
import (
"bytes"
"fmt"
"os/exec"
"strconv"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("non-zero RTT", func() {
var proxy *quicproxy.QuicProxy
runRTTTest := func(rtt time.Duration, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
AfterEach(func() {
err := proxy.Close()
Expect(err).ToNot(HaveOccurred())
time.Sleep(time.Millisecond)
})
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
roundTrips := [...]int{10, 50, 100, 200}
for _, rtt := range roundTrips {
It(fmt.Sprintf("gets a 500kB file with %dms RTT", rtt), func() {
runRTTTest(time.Duration(rtt)*time.Millisecond, version)
})
}
})
}
})

View File

@ -0,0 +1,218 @@
package gquic_test
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
mrand "math/rand"
"net/http"
"os"
"os/exec"
"path/filepath"
"strconv"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("Server tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
var (
serverPort string
tmpDir string
session *Session
client *http.Client
)
generateCA := func() (*rsa.PrivateKey, *x509.Certificate) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
templateRoot := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, templateRoot, templateRoot, &key.PublicKey, key)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return key, cert
}
// prepare the file such that it can be by the quic_server
// some HTTP headers neeed to be prepended, see https://www.chromium.org/quic/playing-with-quic
createDownloadFile := func(filename string, data []byte) {
dataDir := filepath.Join(tmpDir, "quic.clemente.io")
err := os.Mkdir(dataDir, 0777)
Expect(err).ToNot(HaveOccurred())
f, err := os.Create(filepath.Join(dataDir, filename))
Expect(err).ToNot(HaveOccurred())
defer f.Close()
_, err = f.Write([]byte("HTTP/1.1 200 OK\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("Content-Type: text/html\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("X-Original-Url: https://quic.clemente.io:" + serverPort + "/" + filename + "\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("Content-Length: " + strconv.Itoa(len(data)) + "\n\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write(data)
Expect(err).ToNot(HaveOccurred())
}
// download files must be create *before* the quic_server is started
// the quic_server reads its data dir on startup, and only serves those files that were already present then
startServer := func(version protocol.VersionNumber) {
defer GinkgoRecover()
var err error
command := exec.Command(
serverPath,
"--quic_response_cache_dir="+filepath.Join(tmpDir, "quic.clemente.io"),
"--key_file="+filepath.Join(tmpDir, "key.pkcs8"),
"--certificate_file="+filepath.Join(tmpDir, "cert.pem"),
"--quic-version="+strconv.Itoa(int(version)),
"--port="+serverPort,
)
session, err = Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
}
stopServer := func() {
session.Kill()
}
BeforeEach(func() {
serverPort = strconv.Itoa(20000 + int(mrand.Int31n(10000)))
var err error
tmpDir, err = ioutil.TempDir("", "quic-server-certs")
Expect(err).ToNot(HaveOccurred())
// generate an RSA key pair for the server
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
// save the private key in PKCS8 format to disk (required by quic_server)
pkcs8key, err := asn1.Marshal(struct { // copied from the x509 package
Version int
Algo pkix.AlgorithmIdentifier
PrivateKey []byte
}{
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
Algo: pkix.AlgorithmIdentifier{
Algorithm: asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1},
Parameters: asn1.RawValue{Tag: 5},
},
})
Expect(err).ToNot(HaveOccurred())
f, err := os.Create(filepath.Join(tmpDir, "key.pkcs8"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write(pkcs8key)
Expect(err).ToNot(HaveOccurred())
f.Close()
// generate a Certificate Authority
// this CA is used to sign the server's key
// it is set as a valid CA in the QUIC client
rootKey, CACert := generateCA()
// generate the server certificate
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-30 * time.Minute),
NotAfter: time.Now().Add(30 * time.Minute),
Subject: pkix.Name{CommonName: "quic.clemente.io"},
}
certDER, err := x509.CreateCertificate(rand.Reader, template, CACert, &key.PublicKey, rootKey)
Expect(err).ToNot(HaveOccurred())
// save the certificate to disk
certOut, err := os.Create(filepath.Join(tmpDir, "cert.pem"))
Expect(err).ToNot(HaveOccurred())
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
certOut.Close()
// prepare the h2quic.client
certPool := x509.NewCertPool()
certPool.AddCert(CACert)
client = &http.Client{
Transport: &h2quic.RoundTripper{
TLSClientConfig: &tls.Config{RootCAs: certPool},
QuicConfig: &quic.Config{
Versions: []protocol.VersionNumber{version},
},
},
}
})
AfterEach(func() {
Expect(tmpDir).ToNot(BeEmpty())
err := os.RemoveAll(tmpDir)
Expect(err).ToNot(HaveOccurred())
tmpDir = ""
})
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("downloads a hello", func() {
data := []byte("Hello world!\n")
createDownloadFile("hello", data)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/hello")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(data))
})
It("downloads a small file", func() {
createDownloadFile("file.dat", testserver.PRData)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRData))
})
It("downloads a large file", func() {
createDownloadFile("file.dat", testserver.PRDataLong)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 20*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRDataLong))
})
})
}
})

View File

@ -0,0 +1,97 @@
package self_test
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
var _ = Describe("Client tests", func() {
var client *http.Client
// also run some tests with the TLS handshake
versions := append(protocol.SupportedVersions, protocol.VersionTLS)
BeforeEach(func() {
err := os.Setenv("HOSTALIASES", "quic.clemente.io 127.0.0.1")
Expect(err).ToNot(HaveOccurred())
addr, err := net.ResolveUDPAddr("udp4", "quic.clemente.io:0")
Expect(err).ToNot(HaveOccurred())
if addr.String() != "127.0.0.1:0" {
Fail("quic.clemente.io does not resolve to 127.0.0.1. Consider adding it to /etc/hosts.")
}
testserver.StartQuicServer(versions)
})
AfterEach(func() {
testserver.StopQuicServer()
})
for _, v := range versions {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
BeforeEach(func() {
client = &http.Client{
Transport: &h2quic.RoundTripper{
QuicConfig: &quic.Config{
Versions: []protocol.VersionNumber{version},
},
},
}
})
It("downloads a hello", func() {
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/hello")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World!\n"))
})
It("downloads a small file", func() {
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/prdata")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRData))
})
It("downloads a large file", func() {
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/prdatalong")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRDataLong))
})
It("uploads a file", func() {
resp, err := client.Post(
"https://quic.clemente.io:"+testserver.Port()+"/echo",
"text/plain",
bytes.NewReader(testserver.PRData),
)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(bytes.Equal(body, testserver.PRData)).To(BeTrue())
})
})
}
})

View File

@ -0,0 +1,101 @@
package self_test
import (
"crypto/tls"
"fmt"
"io/ioutil"
"math/rand"
"net"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Connection ID lengths tests", func() {
randomConnIDLen := func() int {
return 4 + int(rand.Int31n(15))
}
runServer := func(conf *quic.Config) quic.Listener {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
ln, err := quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), conf)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
for {
sess, err := ln.Accept()
if err != nil {
return
}
go func() {
defer GinkgoRecover()
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
defer str.Close()
_, err = str.Write(testserver.PRData)
Expect(err).ToNot(HaveOccurred())
}()
}
}()
return ln
}
runClient := func(addr net.Addr, conf *quic.Config) {
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
cl, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", addr.(*net.UDPAddr).Port),
&tls.Config{InsecureSkipVerify: true},
conf,
)
Expect(err).ToNot(HaveOccurred())
defer cl.Close()
str, err := cl.AcceptStream()
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRData))
}
Context("IETF QUIC", func() {
It("downloads a file using a 0-byte connection ID for the client", func() {
serverConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
clientConf := &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
It("downloads a file when both client and server use a random connection ID length", func() {
serverConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
clientConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
})
Context("gQUIC", func() {
It("downloads a file using a 0-byte connection ID for the client", func() {
ln := runServer(&quic.Config{})
defer ln.Close()
runClient(ln.Addr(), &quic.Config{RequestConnectionIDOmission: true})
})
})
})

View File

@ -0,0 +1,189 @@
package self_test
import (
"fmt"
mrand "math/rand"
"net"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
)
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
type applicationProtocol struct {
name string
run func(protocol.VersionNumber)
}
var _ = Describe("Handshake drop tests", func() {
var (
proxy *quicproxy.QuicProxy
ln quic.Listener
)
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
var err error
ln, err = quic.ListenAddr(
"localhost:0",
testdata.GetTLSConfig(),
&quic.Config{
Versions: []protocol.VersionNumber{version},
},
)
Expect(err).ToNot(HaveOccurred())
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DropPacket: dropCallback,
},
)
Expect(err).ToNot(HaveOccurred())
}
stochasticDropper := func(freq int) bool {
return mrand.Int63n(int64(freq)) == 0
}
clientSpeaksFirst := &applicationProtocol{
name: "client speaks first",
run: func(version protocol.VersionNumber) {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
defer sess.Close()
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
serverSessionChan <- sess
}()
sess, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
sess.Close()
serverSession.Close()
},
}
serverSpeaksFirst := &applicationProtocol{
name: "server speaks first",
run: func(version protocol.VersionNumber) {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
serverSessionChan <- sess
}()
sess, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
b := make([]byte, 6)
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal("foobar"))
var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
sess.Close()
serverSession.Close()
},
}
nobodySpeaks := &applicationProtocol{
name: "nobody speaks",
run: func(version protocol.VersionNumber) {
serverSessionChan := make(chan quic.Session)
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
serverSessionChan <- sess
}()
sess, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
var serverSession quic.Session
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
// both server and client accepted a session. Close now.
sess.Close()
serverSession.Close()
},
}
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
})
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
for _, d := range directions {
direction := d
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
app := a
Context(app.name, func() {
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 1 && d.Is(direction)
}, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 2 && d.Is(direction)
}, version)
app.run(version)
})
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
return d.Is(direction) && stochasticDropper(5)
}, version)
app.run(version)
})
})
}
}
})
}
})

View File

@ -0,0 +1,213 @@
package self_test
import (
"crypto/tls"
"net"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake RTT tests", func() {
var (
proxy *quicproxy.QuicProxy
server quic.Listener
serverConfig *quic.Config
testStartedAt time.Time
acceptStopped chan struct{}
)
rtt := 400 * time.Millisecond
BeforeEach(func() {
acceptStopped = make(chan struct{})
serverConfig = &quic.Config{}
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
<-acceptStopped
})
runServerAndProxy := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", protocol.VersionWhatever, &quicproxy.Opts{
RemoteAddr: server.Addr().String(),
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
testStartedAt = time.Now()
go func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
_, err := server.Accept()
if err != nil {
return
}
}
}()
}
expectDurationInRTTs := func(num int) {
testDuration := time.Since(testStartedAt)
rtts := float32(testDuration) / float32(rtt)
Expect(rtts).To(SatisfyAll(
BeNumerically(">=", num),
BeNumerically("<", num+1),
))
}
It("fails when there's no matching version, after 1 RTT", func() {
if len(protocol.SupportedVersions) == 1 {
Skip("Test requires at least 2 supported versions.")
}
serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy()
clientConfig := &quic.Config{
Versions: protocol.SupportedVersions[1:2],
}
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
expectDurationInRTTs(1)
})
Context("gQUIC", func() {
// 1 RTT for verifying the source address
// 1 RTT to become secure
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
It("does version negotiation in 1 RTT, IETF QUIC => gQUIC", func() {
clientConfig := &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS, protocol.SupportedVersions[0]},
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
&tls.Config{InsecureSkipVerify: true},
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(4)
})
It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
expectDurationInRTTs(3)
})
})
Context("IETF QUIC", func() {
var clientConfig *quic.Config
var clientTLSConfig *tls.Config
BeforeEach(func() {
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
clientConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
clientTLSConfig = &tls.Config{
InsecureSkipVerify: true,
ServerName: "quic.clemente.io",
}
})
// 1 RTT for verifying the source address
// 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("does version negotiation in 1 RTT, gQUIC => IETF QUIC", func() {
clientConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[0], protocol.VersionTLS}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
It("is forward-secure after 1 RTTs when the server doesn't require a Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(1)
})
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.CryptoTooManyRejects))
})
})
})

View File

@ -0,0 +1,128 @@
package self_test
import (
"crypto/tls"
"fmt"
"net"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type versioner interface {
GetVersion() protocol.VersionNumber
}
var _ = Describe("Handshake tests", func() {
var (
server quic.Listener
serverConfig *quic.Config
acceptStopped chan struct{}
)
BeforeEach(func() {
server = nil
acceptStopped = make(chan struct{})
serverConfig = &quic.Config{}
})
AfterEach(func() {
if server != nil {
server.Close()
<-acceptStopped
}
})
runServer := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer close(acceptStopped)
for {
_, err := server.Accept()
if err != nil {
return
}
}
}()
}
Context("Version Negotiation", func() {
var supportedVersions []protocol.VersionNumber
BeforeEach(func() {
supportedVersions = protocol.SupportedVersions
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...)
})
AfterEach(func() {
protocol.SupportedVersions = supportedVersions
})
It("when the server supports more versions than the client", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[1], 7, 8, 9}
runServer()
sess, err := quic.DialAddr(server.Addr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[1]))
})
It("when the client supports more versions than the server supports", func() {
// the server doesn't support the highest supported version, which is the first one the client will try
// but it supports a bunch of versions that the client doesn't speak
serverConfig.Versions = supportedVersions
runServer()
conf := &quic.Config{
Versions: []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[1], 10},
}
sess, err := quic.DialAddr(server.Addr().String(), &tls.Config{InsecureSkipVerify: true}, conf)
Expect(err).ToNot(HaveOccurred())
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[1]))
})
})
Context("Certifiate validation", func() {
for _, v := range []protocol.VersionNumber{protocol.Version39, protocol.VersionTLS} {
version := v
Context(fmt.Sprintf("using %s", version), func() {
var clientConfig *quic.Config
BeforeEach(func() {
serverConfig.Versions = []protocol.VersionNumber{version}
clientConfig = &quic.Config{
Versions: []protocol.VersionNumber{version},
}
})
It("accepts the certificate", func() {
runServer()
_, err := quic.DialAddr(fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port), nil, clientConfig)
Expect(err).ToNot(HaveOccurred())
})
It("errors if the server name doesn't match", func() {
runServer()
_, err := quic.DialAddr(fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), nil, clientConfig)
Expect(err).To(HaveOccurred())
})
It("uses the ServerName in the tls.Config", func() {
runServer()
conf := &tls.Config{ServerName: "quic.clemente.io"}
_, err := quic.DialAddr(fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), conf, clientConfig)
Expect(err).ToNot(HaveOccurred())
})
})
}
})
})

View File

@ -0,0 +1,232 @@
package self_test
import (
"fmt"
"io/ioutil"
"net"
"os"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Multiplexing", func() {
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
version := v
// gQUIC 44 uses 0 byte connection IDs for packets sent to the client
// It's not possible to do demultiplexing.
if v == protocol.Version44 {
continue
}
Context(fmt.Sprintf("with QUIC version %s", version), func() {
runServer := func(ln quic.Listener) {
go func() {
defer GinkgoRecover()
for {
sess, err := ln.Accept()
if err != nil {
return
}
go func() {
defer GinkgoRecover()
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
defer str.Close()
_, err = str.Write(testserver.PRDataLong)
Expect(err).ToNot(HaveOccurred())
}()
}
}()
}
dial := func(conn net.PacketConn, addr net.Addr) {
sess, err := quic.Dial(
conn,
addr,
fmt.Sprintf("quic.clemente.io:%d", addr.(*net.UDPAddr).Port),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRDataLong))
}
Context("multiplexing clients on the same conn", func() {
getListener := func() quic.Listener {
ln, err := quic.ListenAddr(
"localhost:0",
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
return ln
}
It("multiplexes connections to the same server", func() {
server := getListener()
runServer(server)
defer server.Close()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
close(done2)
}()
timeout := 30 * time.Second
if testlog.Debug() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
It("multiplexes connections to different servers", func() {
server1 := getListener()
runServer(server1)
defer server1.Close()
server2 := getListener()
runServer(server2)
defer server2.Close()
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second
if testlog.Debug() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
})
Context("multiplexing server and client on the same conn", func() {
It("connects to itself", func() {
if version != protocol.VersionTLS {
Skip("Connecting to itself only works with IETF QUIC.")
}
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
server, err := quic.Listen(
conn,
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
runServer(server)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn, server.Addr())
close(done)
}()
timeout := 30 * time.Second
if testlog.Debug() {
timeout = time.Minute
}
Eventually(done, timeout).Should(BeClosed())
})
It("runs a server and client on the same conn", func() {
if os.Getenv("CI") == "true" {
Skip("This test is flaky on CIs, see see https://github.com/golang/go/issues/17677.")
}
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn1, err := net.ListenUDP("udp", addr1)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
server1, err := quic.Listen(
conn1,
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
runServer(server1)
defer server1.Close()
server2, err := quic.Listen(
conn2,
testdata.GetTLSConfig(),
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
runServer(server2)
defer server2.Close()
done1 := make(chan struct{})
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
dial(conn2, server1.Addr())
close(done1)
}()
go func() {
defer GinkgoRecover()
dial(conn1, server2.Addr())
close(done2)
}()
timeout := 30 * time.Second
if testlog.Debug() {
timeout = time.Minute
}
Eventually(done1, timeout).Should(BeClosed())
Eventually(done2, timeout).Should(BeClosed())
})
})
})
}
})

View File

@ -0,0 +1,83 @@
package self
import (
"fmt"
"io/ioutil"
"net"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("non-zero RTT", func() {
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
roundTrips := [...]time.Duration{
10 * time.Millisecond,
50 * time.Millisecond,
100 * time.Millisecond,
200 * time.Millisecond,
}
for _, r := range roundTrips {
rtt := r
It(fmt.Sprintf("downloads a message with %s RTT", rtt), func() {
ln, err := quic.ListenAddr(
"localhost:0",
testdata.GetTLSConfig(),
&quic.Config{
Versions: []protocol.VersionNumber{version},
},
)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := ln.Accept()
Expect(err).ToNot(HaveOccurred())
str, err := sess.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(testserver.PRData)
Expect(err).ToNot(HaveOccurred())
str.Close()
close(done)
}()
serverPort := ln.Addr().(*net.UDPAddr).Port
proxy, err := quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(d quicproxy.Direction, p uint64) time.Duration {
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()
sess, err := quic.DialAddr(
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
nil,
&quic.Config{Versions: []protocol.VersionNumber{version}},
)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(testserver.PRData))
sess.Close()
Eventually(done).Should(BeClosed())
})
}
})
}
})

View File

@ -0,0 +1,20 @@
package self_test
import (
"math/rand"
"testing"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
)
func TestSelf(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Self integration tests")
}
var _ = BeforeSuite(func() {
rand.Seed(GinkgoRandomSeed())
})

View File

@ -0,0 +1,152 @@
package self_test
import (
"fmt"
"io/ioutil"
"net"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Bidirectional streams", func() {
const numStreams = 300
var (
server quic.Listener
serverAddr string
qconf *quic.Config
)
for _, v := range []protocol.VersionNumber{protocol.VersionTLS} {
version := v
Context(fmt.Sprintf("with QUIC %s", version), func() {
BeforeEach(func() {
var err error
qconf = &quic.Config{
Versions: []protocol.VersionNumber{version},
MaxIncomingStreams: 0,
}
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), qconf)
Expect(err).ToNot(HaveOccurred())
serverAddr = fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port)
})
AfterEach(func() {
server.Close()
})
runSendingPeer := func(sess quic.Session) {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.OpenStreamSync()
Expect(err).ToNot(HaveOccurred())
data := testserver.GeneratePRData(25 * i)
go func() {
defer GinkgoRecover()
_, err := str.Write(data)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
go func() {
defer GinkgoRecover()
defer wg.Done()
dataRead, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(dataRead).To(Equal(data))
}()
}
wg.Wait()
}
runReceivingPeer := func(sess quic.Session) {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer wg.Done()
// shouldn't use io.Copy here
// we should read from the stream as early as possible, to free flow control credit
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
_, err = str.Write(data)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
}
wg.Wait()
}
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
var sess quic.Session
go func() {
defer GinkgoRecover()
var err error
sess, err = server.Accept()
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess)
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
runSendingPeer(client)
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess)
sess.Close()
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client)
Eventually(client.Context().Done()).Should(BeClosed())
})
It(fmt.Sprintf("client and server opening %d each and sending data to the peer", numStreams), func() {
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
runReceivingPeer(sess)
close(done)
}()
runSendingPeer(sess)
<-done
close(done1)
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
runSendingPeer(client)
close(done2)
}()
runReceivingPeer(client)
<-done1
<-done2
})
})
}
})

View File

@ -0,0 +1,132 @@
package self_test
import (
"fmt"
"io/ioutil"
"net"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Unidirectional Streams", func() {
const numStreams = 500
var (
server quic.Listener
serverAddr string
qconf *quic.Config
)
BeforeEach(func() {
var err error
qconf = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), qconf)
Expect(err).ToNot(HaveOccurred())
serverAddr = fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port)
})
AfterEach(func() {
server.Close()
})
dataForStream := func(id protocol.StreamID) []byte {
return testserver.GeneratePRData(10 * int(id))
}
runSendingPeer := func(sess quic.Session) {
for i := 0; i < numStreams; i++ {
str, err := sess.OpenUniStreamSync()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
_, err := str.Write(dataForStream(str.StreamID()))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
}()
}
}
runReceivingPeer := func(sess quic.Session) {
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
str, err := sess.AcceptUniStream()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer wg.Done()
data, err := ioutil.ReadAll(str)
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(dataForStream(str.StreamID())))
}()
}
wg.Wait()
}
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
var sess quic.Session
go func() {
defer GinkgoRecover()
var err error
sess, err = server.Accept()
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(sess)
sess.Close()
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
runSendingPeer(client)
<-client.Context().Done()
})
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
runSendingPeer(sess)
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
runReceivingPeer(client)
})
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
sess, err := server.Accept()
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
runReceivingPeer(sess)
close(done)
}()
runSendingPeer(sess)
<-done
close(done1)
}()
client, err := quic.DialAddr(serverAddr, nil, qconf)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
runSendingPeer(client)
close(done2)
}()
runReceivingPeer(client)
<-done1
<-done2
})
})

View File

@ -0,0 +1,270 @@
package quicproxy
import (
"net"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Connection is a UDP connection
type connection struct {
ClientAddr *net.UDPAddr // Address of the client
ServerConn *net.UDPConn // UDP connection to server
incomingPacketCounter uint64
outgoingPacketCounter uint64
}
// Direction is the direction a packet is sent.
type Direction int
const (
// DirectionIncoming is the direction from the client to the server.
DirectionIncoming Direction = iota
// DirectionOutgoing is the direction from the server to the client.
DirectionOutgoing
// DirectionBoth is both incoming and outgoing
DirectionBoth
)
func (d Direction) String() string {
switch d {
case DirectionIncoming:
return "incoming"
case DirectionOutgoing:
return "outgoing"
case DirectionBoth:
return "both"
default:
panic("unknown direction")
}
}
// Is says if one direction matches another direction.
// For example, incoming matches both incoming and both, but not outgoing.
func (d Direction) Is(dir Direction) bool {
if d == DirectionBoth || dir == DirectionBoth {
return true
}
return d == dir
}
// DropCallback is a callback that determines which packet gets dropped.
type DropCallback func(dir Direction, packetCount uint64) bool
// NoDropper doesn't drop packets.
var NoDropper DropCallback = func(Direction, uint64) bool {
return false
}
// DelayCallback is a callback that determines how much delay to apply to a packet.
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
// NoDelay doesn't apply a delay.
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
return 0
}
// Opts are proxy options.
type Opts struct {
// The address this proxy proxies packets to.
RemoteAddr string
// DropPacket determines whether a packet gets dropped.
DropPacket DropCallback
// DelayPacket determines how long a packet gets delayed. This allows
// simulating a connection with non-zero RTTs.
// Note that the RTT is the sum of the delay for the incoming and the outgoing packet.
DelayPacket DelayCallback
}
// QuicProxy is a QUIC proxy that can drop and delay packets.
type QuicProxy struct {
mutex sync.Mutex
version protocol.VersionNumber
conn *net.UDPConn
serverAddr *net.UDPAddr
dropPacket DropCallback
delayPacket DelayCallback
// Mapping from client addresses (as host:port) to connection
clientDict map[string]*connection
logger utils.Logger
}
// NewQuicProxy creates a new UDP proxy
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
if opts == nil {
opts = &Opts{}
}
laddr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", laddr)
if err != nil {
return nil, err
}
raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr)
if err != nil {
return nil, err
}
packetDropper := NoDropper
if opts.DropPacket != nil {
packetDropper = opts.DropPacket
}
packetDelayer := NoDelay
if opts.DelayPacket != nil {
packetDelayer = opts.DelayPacket
}
p := QuicProxy{
clientDict: make(map[string]*connection),
conn: conn,
serverAddr: raddr,
dropPacket: packetDropper,
delayPacket: packetDelayer,
version: version,
logger: utils.DefaultLogger.WithPrefix("proxy"),
}
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
go p.runProxy()
return &p, nil
}
// Close stops the UDP Proxy
func (p *QuicProxy) Close() error {
p.mutex.Lock()
defer p.mutex.Unlock()
for _, c := range p.clientDict {
if err := c.ServerConn.Close(); err != nil {
return err
}
}
return p.conn.Close()
}
// LocalAddr is the address the proxy is listening on.
func (p *QuicProxy) LocalAddr() net.Addr {
return p.conn.LocalAddr()
}
// LocalPort is the UDP port number the proxy is listening on.
func (p *QuicProxy) LocalPort() int {
return p.conn.LocalAddr().(*net.UDPAddr).Port
}
func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
srvudp, err := net.DialUDP("udp", nil, p.serverAddr)
if err != nil {
return nil, err
}
return &connection{
ClientAddr: cliAddr,
ServerConn: srvudp,
}, nil
}
// runProxy listens on the proxy address and handles incoming packets.
func (p *QuicProxy) runProxy() error {
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
saddr := cliaddr.String()
p.mutex.Lock()
conn, ok := p.clientDict[saddr]
if !ok {
conn, err = p.newConnection(cliaddr)
if err != nil {
p.mutex.Unlock()
return err
}
p.clientDict[saddr] = conn
go p.runConnection(conn)
}
p.mutex.Unlock()
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
if p.dropPacket(DirectionIncoming, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
}
continue
}
// Send the packet to the server
delay := p.delayPacket(DirectionIncoming, packetCount)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = conn.ServerConn.Write(raw)
})
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
}
if _, err := conn.ServerConn.Write(raw); err != nil {
return err
}
}
}
}
// runConnection handles packets from server to a single client
func (p *QuicProxy) runConnection(conn *connection) error {
for {
buffer := make([]byte, protocol.MaxReceivePacketSize)
n, err := conn.ServerConn.Read(buffer)
if err != nil {
return err
}
raw := buffer[0:n]
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
if p.dropPacket(DirectionOutgoing, packetCount) {
if p.logger.Debug() {
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
}
continue
}
delay := p.delayPacket(DirectionOutgoing, packetCount)
if delay != 0 {
if p.logger.Debug() {
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
}
time.AfterFunc(delay, func() {
// TODO: handle error
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
})
} else {
if p.logger.Debug() {
p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
}
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
return err
}
}
}
}

View File

@ -0,0 +1,13 @@
package quicproxy
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestQuicGo(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "QUIC Proxy")
}

View File

@ -0,0 +1,394 @@
package quicproxy
import (
"bytes"
"net"
"runtime/pprof"
"strconv"
"strings"
"sync/atomic"
"time"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type packetData []byte
var _ = Describe("QUIC Proxy", func() {
makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
b := &bytes.Buffer{}
hdr := wire.Header{
PacketNumber: p,
PacketNumberLen: protocol.PacketNumberLen6,
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
}
hdr.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever)
raw := b.Bytes()
raw = append(raw, payload...)
return raw
}
Context("Proxy setup and teardown", func() {
It("sets up the UDPProxy", func() {
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
Expect(err).ToNot(HaveOccurred())
Expect(proxy.clientDict).To(HaveLen(0))
// check that the proxy port is in use
addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(proxy.LocalPort()))
Expect(err).ToNot(HaveOccurred())
_, err = net.ListenUDP("udp", addr)
Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort())))
Expect(proxy.Close()).To(Succeed()) // stopping is tested in the next test
})
It("stops the UDPProxy", func() {
isProxyRunning := func() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "proxy.(*QuicProxy).runProxy")
}
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
Expect(err).ToNot(HaveOccurred())
port := proxy.LocalPort()
Expect(isProxyRunning()).To(BeTrue())
err = proxy.Close()
Expect(err).ToNot(HaveOccurred())
// check that the proxy port is not in use anymore
addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(port))
Expect(err).ToNot(HaveOccurred())
// sometimes it takes a while for the OS to free the port
Eventually(func() error {
ln, err := net.ListenUDP("udp", addr)
if err != nil {
return err
}
ln.Close()
return nil
}).ShouldNot(HaveOccurred())
Eventually(isProxyRunning).Should(BeFalse())
})
It("stops listening for proxied connections", func() {
isConnRunning := func() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "proxy.(*QuicProxy).runConnection")
}
serverAddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err := net.ListenUDP("udp", serverAddr)
Expect(err).ToNot(HaveOccurred())
defer serverConn.Close()
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, &Opts{RemoteAddr: serverConn.LocalAddr().String()})
Expect(err).ToNot(HaveOccurred())
Expect(isConnRunning()).To(BeFalse())
// check that the proxy port is not in use anymore
conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
Expect(err).ToNot(HaveOccurred())
_, err = conn.Write(makePacket(1, []byte("foobar")))
Expect(err).ToNot(HaveOccurred())
Eventually(isConnRunning).Should(BeTrue())
Expect(proxy.Close()).To(Succeed())
Eventually(isConnRunning).Should(BeFalse())
})
It("has the correct LocalAddr and LocalPort", func() {
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
Expect(err).ToNot(HaveOccurred())
Expect(proxy.LocalAddr().String()).To(Equal("127.0.0.1:" + strconv.Itoa(proxy.LocalPort())))
Expect(proxy.LocalPort()).ToNot(BeZero())
Expect(proxy.Close()).To(Succeed())
})
})
Context("Proxy tests", func() {
var (
serverConn *net.UDPConn
serverNumPacketsSent int32
serverReceivedPackets chan packetData
clientConn *net.UDPConn
proxy *QuicProxy
)
startProxy := func(opts *Opts) {
var err error
proxy, err = NewQuicProxy("localhost:0", protocol.VersionWhatever, opts)
Expect(err).ToNot(HaveOccurred())
clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
Expect(err).ToNot(HaveOccurred())
}
// getClientDict returns a copy of the clientDict map
getClientDict := func() map[string]*connection {
d := make(map[string]*connection)
proxy.mutex.Lock()
defer proxy.mutex.Unlock()
for k, v := range proxy.clientDict {
d[k] = v
}
return d
}
BeforeEach(func() {
serverReceivedPackets = make(chan packetData, 100)
atomic.StoreInt32(&serverNumPacketsSent, 0)
// setup a dump UDP server
// in production this would be a QUIC server
raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
Expect(err).ToNot(HaveOccurred())
serverConn, err = net.ListenUDP("udp", raddr)
Expect(err).ToNot(HaveOccurred())
go func() {
for {
buf := make([]byte, protocol.MaxReceivePacketSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, addr, err2 := serverConn.ReadFromUDP(buf)
if err2 != nil {
return
}
data := buf[0:n]
serverReceivedPackets <- packetData(data)
// echo the packet
serverConn.WriteToUDP(data, addr)
atomic.AddInt32(&serverNumPacketsSent, 1)
}
}()
})
AfterEach(func() {
err := proxy.Close()
Expect(err).ToNot(HaveOccurred())
err = serverConn.Close()
Expect(err).ToNot(HaveOccurred())
err = clientConn.Close()
Expect(err).ToNot(HaveOccurred())
time.Sleep(200 * time.Millisecond)
})
Context("no packet drop", func() {
It("relays packets from the client to the server", func() {
startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
// send the first packet
_, err := clientConn.Write(makePacket(1, []byte("foobar")))
Expect(err).ToNot(HaveOccurred())
Eventually(getClientDict).Should(HaveLen(1))
var conn *connection
for _, conn = range getClientDict() {
Eventually(func() uint64 { return atomic.LoadUint64(&conn.incomingPacketCounter) }).Should(Equal(uint64(1)))
}
// send the second packet
_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
Expect(err).ToNot(HaveOccurred())
Eventually(serverReceivedPackets).Should(HaveLen(2))
Expect(getClientDict()).To(HaveLen(1))
Expect(string(<-serverReceivedPackets)).To(ContainSubstring("foobar"))
Expect(string(<-serverReceivedPackets)).To(ContainSubstring("decafbad"))
})
It("relays packets from the server to the client", func() {
startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
// send the first packet
_, err := clientConn.Write(makePacket(1, []byte("foobar")))
Expect(err).ToNot(HaveOccurred())
Eventually(getClientDict).Should(HaveLen(1))
var key string
var conn *connection
for key, conn = range getClientDict() {
Eventually(func() uint64 { return atomic.LoadUint64(&conn.outgoingPacketCounter) }).Should(Equal(uint64(1)))
}
// send the second packet
_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
Expect(err).ToNot(HaveOccurred())
Expect(getClientDict()).To(HaveLen(1))
Eventually(func() uint64 {
conn := getClientDict()[key]
return atomic.LoadUint64(&conn.outgoingPacketCounter)
}).Should(BeEquivalentTo(2))
clientReceivedPackets := make(chan packetData, 2)
// receive the packets echoed by the server on client side
go func() {
for {
buf := make([]byte, protocol.MaxReceivePacketSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, _, err2 := clientConn.ReadFromUDP(buf)
if err2 != nil {
return
}
data := buf[0:n]
clientReceivedPackets <- packetData(data)
}
}()
Eventually(serverReceivedPackets).Should(HaveLen(2))
Expect(atomic.LoadInt32(&serverNumPacketsSent)).To(BeEquivalentTo(2))
Eventually(clientReceivedPackets).Should(HaveLen(2))
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("foobar"))
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("decafbad"))
})
})
Context("Drop Callbacks", func() {
It("drops incoming packets", func() {
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
DropPacket: func(d Direction, p uint64) bool {
return d == DirectionIncoming && p%2 == 0
},
}
startProxy(opts)
for i := 1; i <= 6; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
Eventually(serverReceivedPackets).Should(HaveLen(3))
Consistently(serverReceivedPackets).Should(HaveLen(3))
})
It("drops outgoing packets", func() {
const numPackets = 6
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
DropPacket: func(d Direction, p uint64) bool {
return d == DirectionOutgoing && p%2 == 0
},
}
startProxy(opts)
clientReceivedPackets := make(chan packetData, numPackets)
// receive the packets echoed by the server on client side
go func() {
for {
buf := make([]byte, protocol.MaxReceivePacketSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, _, err2 := clientConn.ReadFromUDP(buf)
if err2 != nil {
return
}
data := buf[0:n]
clientReceivedPackets <- packetData(data)
}
}()
for i := 1; i <= numPackets; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
Eventually(clientReceivedPackets).Should(HaveLen(numPackets / 2))
Consistently(clientReceivedPackets).Should(HaveLen(numPackets / 2))
})
})
Context("Delay Callback", func() {
expectDelay := func(startTime time.Time, rtt time.Duration, numRTTs int) {
expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * rtt)
Expect(time.Now()).To(SatisfyAll(
BeTemporally(">=", expectedReceiveTime),
BeTemporally("<", expectedReceiveTime.Add(rtt/2)),
))
}
It("delays incoming packets", func() {
delay := 300 * time.Millisecond
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
// delay packet 1 by 200 ms
// delay packet 2 by 400 ms
// ...
DelayPacket: func(d Direction, p uint64) time.Duration {
if d == DirectionOutgoing {
return 0
}
return time.Duration(p) * delay
},
}
startProxy(opts)
// send 3 packets
start := time.Now()
for i := 1; i <= 3; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
Eventually(serverReceivedPackets).Should(HaveLen(1))
expectDelay(start, delay, 1)
Eventually(serverReceivedPackets).Should(HaveLen(2))
expectDelay(start, delay, 2)
Eventually(serverReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 3)
})
It("delays outgoing packets", func() {
const numPackets = 3
delay := 300 * time.Millisecond
opts := &Opts{
RemoteAddr: serverConn.LocalAddr().String(),
// delay packet 1 by 200 ms
// delay packet 2 by 400 ms
// ...
DelayPacket: func(d Direction, p uint64) time.Duration {
if d == DirectionIncoming {
return 0
}
return time.Duration(p) * delay
},
}
startProxy(opts)
clientReceivedPackets := make(chan packetData, numPackets)
// receive the packets echoed by the server on client side
go func() {
for {
buf := make([]byte, protocol.MaxReceivePacketSize)
// the ReadFromUDP will error as soon as the UDP conn is closed
n, _, err2 := clientConn.ReadFromUDP(buf)
if err2 != nil {
return
}
data := buf[0:n]
clientReceivedPackets <- packetData(data)
}
}()
start := time.Now()
for i := 1; i <= numPackets; i++ {
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
Expect(err).ToNot(HaveOccurred())
}
// the packets should have arrived immediately at the server
Eventually(serverReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 0)
Eventually(clientReceivedPackets).Should(HaveLen(1))
expectDelay(start, delay, 1)
Eventually(clientReceivedPackets).Should(HaveLen(2))
expectDelay(start, delay, 2)
Eventually(clientReceivedPackets).Should(HaveLen(3))
expectDelay(start, delay, 3)
})
})
})
})

View File

@ -0,0 +1,46 @@
package testlog
import (
"flag"
"log"
"os"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var (
logFileName string // the log file set in the ginkgo flags
logFile *os.File
)
// read the logfile command line flag
// to set call ginkgo -- -logfile=log.txt
func init() {
flag.StringVar(&logFileName, "logfile", "", "log file")
}
var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
if len(logFileName) > 0 {
var err error
logFile, err = os.Create(logFileName)
Expect(err).ToNot(HaveOccurred())
log.SetOutput(logFile)
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
}
})
var _ = AfterEach(func() {
if len(logFileName) > 0 {
_ = logFile.Close()
}
})
// Debug says if this test is being logged
func Debug() bool {
return len(logFileName) > 0
}

View File

@ -0,0 +1,119 @@
package testserver
import (
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const (
dataLen = 500 * 1024 // 500 KB
dataLenLong = 50 * 1024 * 1024 // 50 MB
)
var (
// PRData contains dataLen bytes of pseudo-random data.
PRData = GeneratePRData(dataLen)
// PRDataLong contains dataLenLong bytes of pseudo-random data.
PRDataLong = GeneratePRData(dataLenLong)
server *h2quic.Server
stoppedServing chan struct{}
port string
)
func init() {
http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
sl := r.URL.Query().Get("len")
if sl != "" {
var err error
l, err := strconv.Atoi(sl)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(GeneratePRData(l))
Expect(err).NotTo(HaveOccurred())
} else {
_, err := w.Write(PRData)
Expect(err).NotTo(HaveOccurred())
}
})
http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := w.Write(PRDataLong)
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
_, err := io.WriteString(w, "Hello, World!\n")
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
body, err := ioutil.ReadAll(r.Body)
Expect(err).NotTo(HaveOccurred())
_, err = w.Write(body)
Expect(err).NotTo(HaveOccurred())
})
}
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
func GeneratePRData(l int) []byte {
res := make([]byte, l)
seed := uint64(1)
for i := 0; i < l; i++ {
seed = seed * 48271 % 2147483647
res[i] = byte(seed)
}
return res
}
// StartQuicServer starts a h2quic.Server.
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
func StartQuicServer(versions []protocol.VersionNumber) {
server = &h2quic.Server{
Server: &http.Server{
TLSConfig: testdata.GetTLSConfig(),
},
QuicConfig: &quic.Config{
Versions: versions,
},
}
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
Expect(err).NotTo(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)
Expect(err).NotTo(HaveOccurred())
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
stoppedServing = make(chan struct{})
go func() {
defer GinkgoRecover()
server.Serve(conn)
close(stoppedServing)
}()
}
// StopQuicServer stops the h2quic.Server.
func StopQuicServer() {
Expect(server.Close()).NotTo(HaveOccurred())
Eventually(stoppedServing).Should(BeClosed())
}
// Port returns the UDP port of the QUIC server.
func Port() string {
return port
}

View File

@ -0,0 +1,223 @@
package quic
import (
"context"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// The StreamID is the ID of a QUIC stream.
type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
const (
// VersionGQUIC39 is gQUIC version 39.
VersionGQUIC39 = protocol.Version39
// VersionGQUIC43 is gQUIC version 43.
VersionGQUIC43 = protocol.Version43
// VersionGQUIC44 is gQUIC version 44.
VersionGQUIC44 = protocol.Version44
// VersionMilestone0_10_0 uses TLS
VersionMilestone0_10_0 = protocol.VersionMilestone0_10_0
)
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode
// Stream is the interface implemented by QUIC streams
type Stream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer
// CancelWrite aborts sending on this stream.
// It must not be called after Close.
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
CancelWrite(ErrorCode) error
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
CancelRead(ErrorCode) error
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
SetWriteDeadline(t time.Time) error
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
SetDeadline(t time.Time) error
}
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Read
io.Reader
// see Stream.CancelRead
CancelRead(ErrorCode) error
// see Stream.SetReadDealine
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Write
io.Writer
// see Stream.Close
io.Closer
// see Stream.CancelWrite
CancelWrite(ErrorCode) error
// see Stream.Context
Context() context.Context
// see Stream.SetWriteDeadline
SetWriteDeadline(t time.Time) error
}
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ErrorCode
}
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
AcceptStream() (Stream, error)
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
AcceptUniStream() (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// TODO(#1152): Enable testing for the special error
OpenStream() (Stream, error)
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenStreamSync() (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// TODO(#1152): Enable testing for the special error
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenUniStreamSync() (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
RemoteAddr() net.Addr
// Close the connection.
io.Closer
// Close the connection with an error.
// The error must not be nil.
CloseWithError(ErrorCode, error) error
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
// ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
}
// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []VersionNumber
// Ask the server to omit the connection ID sent in the Public Header.
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDOmission bool
// The length of the connection ID in bytes. Only valid for IETF QUIC.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
HandshakeTimeout time.Duration
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
// This value only applies after the handshake has completed.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds.
IdleTimeout time.Duration
// AcceptCookie determines if a Cookie is accepted.
// It is called with cookie = nil if the client didn't send an Cookie.
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
// This option is only valid for the server.
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
MaxReceiveStreamFlowControlWindow uint64
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow uint64
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// This value doesn't have any effect in Google QUIC.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingUniStreams int
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
}
// A Listener for incoming QUIC connections
type Listener interface {
// Close the server, sending CONNECTION_CLOSE frames to each peer.
Close() error
// Addr returns the local network addr that the server is listening on.
Addr() net.Addr
// Accept returns new sessions. It should be called in a loop.
Accept() (Session, error)
}

View File

@ -0,0 +1,24 @@
package ackhandler
import (
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestCrypto(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "AckHandler Suite")
}
var mockCtrl *gomock.Controller
var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT())
})
var _ = AfterEach(func() {
mockCtrl.Finish()
})

View File

@ -0,0 +1,3 @@
package ackhandler
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet

View File

@ -0,0 +1,47 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet)
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
// The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
// It always returns a number greater or equal than 1.
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
// Note that the number of packets is only calculated based on the pacing algorithm.
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() *Packet
DequeueProbePacket() (*Packet, error)
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
GetAlarmTimeout() time.Time
OnAlarm() error
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
}

View File

@ -0,0 +1,29 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// A Packet is a packet
type Packet struct {
PacketNumber protocol.PacketNumber
PacketType protocol.PacketType
Frames []wire.Frame
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
// There are two reasons why a packet cannot be retransmitted:
// * it was already retransmitted
// * this packet is a retransmission, and we already received an ACK for the original packet
canBeRetransmitted bool
includedInBytesInFlight bool
retransmittedAs []protocol.PacketNumber
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
retransmissionOf protocol.PacketNumber
}

View File

@ -0,0 +1,217 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package ackhandler
// Linked list implementation from the Go standard library.
// PacketElement is an element of a linked list.
type PacketElement struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *PacketElement
// The list to which this element belongs.
list *PacketList
// The value stored with this element.
Value Packet
}
// Next returns the next list element or nil.
func (e *PacketElement) Next() *PacketElement {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *PacketElement) Prev() *PacketElement {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// PacketList is a linked list of Packets.
type PacketList struct {
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *PacketList) Init() *PacketList {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewPacketList returns an initialized list.
func NewPacketList() *PacketList { return new(PacketList).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *PacketList) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *PacketList) Front() *PacketElement {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *PacketList) Back() *PacketElement {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *PacketList) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
return l.insert(&PacketElement{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *PacketList) remove(e *PacketElement) *PacketElement {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *PacketList) Remove(e *PacketElement) Packet {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *PacketList) PushFront(v Packet) *PacketElement {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *PacketList) PushBack(v Packet) *PacketElement {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToFront(e *PacketElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToBack(e *PacketElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushBackList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushFrontList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

View File

@ -0,0 +1,215 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
version protocol.VersionNumber
}
const (
// maximum delay that can be applied to an ACK for a retransmittable packet
ackSendDelay = 25 * time.Millisecond
// initial maximum number of retransmittable packets received before sending an ack.
initialRetransmittablePacketsBeforeAck = 2
// number of retransmittable that an ACK is sent for
retransmittablePacketsBeforeAck = 10
// 1/5 RTT delay when doing ack decimation
ackDecimationDelay = 1.0 / 4
// 1/8 RTT delay when doing ack decimation
shortAckDecimationDelay = 1.0 / 8
// Minimum number of packets received before ack decimation is enabled.
// This intends to avoid the beginning of slow start, when CWNDs may be
// rapidly increasing.
minReceivedBeforeAckDecimation = 100
// Maximum number of packets to ack immediately after a missing packet for
// fast retransmission to kick in at the sender. This limit is created to
// reduce the number of acks sent that have no benefit for fast retransmission.
// Set to the number of nacks needed for fast retransmit plus one for protection
// against an ack loss
maxPacketsAfterNewMissing = 4
)
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(
rttStats *congestion.RTTStats,
logger utils.Logger,
version protocol.VersionNumber,
) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: ackSendDelay,
rttStats: rttStats,
logger: logger,
version: version,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil
}
// IgnoreBelow sets a lower limit for acking packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
if p <= h.ignoreBelow {
return
}
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
}
// maybeQueueAck queues an ACK, if necessary.
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
// in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
h.ackQueued = true
return
}
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
}
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
}
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
} else {
// send an ACK every 2 retransmittable packets
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
}
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
ackTime := rcvTime.Add(ackDelay)
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
now := time.Now()
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
ack := &wire.AckFrame{
AckRanges: h.packetHistory.GetAckRanges(),
DelayTime: now.Sub(h.largestObservedReceivedTime),
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -0,0 +1,382 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("receivedPacketHandler", func() {
var (
handler *receivedPacketHandler
rttStats *congestion.RTTStats
)
BeforeEach(func() {
rttStats = &congestion.RTTStats{}
handler = NewReceivedPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*receivedPacketHandler)
})
Context("accepting packets", func() {
It("handles a packet that arrives late", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(3), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
})
It("saves the time when each packet arrived", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
})
It("updates the largestObserved and the largestObservedReceivedTime", func() {
now := time.Now()
handler.largestObserved = 3
handler.largestObservedReceivedTime = now.Add(-1 * time.Second)
err := handler.ReceivedPacket(5, now, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(Equal(now))
})
It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() {
now := time.Now()
timestamp := now.Add(-1 * time.Second)
handler.largestObserved = 5
handler.largestObservedReceivedTime = timestamp
err := handler.ReceivedPacket(4, now, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
})
It("passes on errors from receivedPacketHistory", func() {
var err error
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
err = handler.ReceivedPacket(2*i+1, time.Time{}, true)
// this will eventually return an error
// details about when exactly the receivedPacketHistory errors are tested there
if err != nil {
break
}
}
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
})
})
Context("ACKs", func() {
Context("queueing ACKs", func() {
receiveAndAck10Packets := func() {
for i := 1; i <= 10; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.ackQueued).To(BeFalse())
}
receiveAndAckPacketsUntilAckDecimation := func() {
for i := 1; i <= minReceivedBeforeAckDecimation; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.ackQueued).To(BeFalse())
}
It("always queues an ACK for the first packet", func() {
err := handler.ReceivedPacket(1, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
})
It("works with packet number 0", func() {
err := handler.ReceivedPacket(0, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
})
It("queues an ACK for every second retransmittable packet at the beginning", func() {
receiveAndAck10Packets()
p := protocol.PacketNumber(11)
for i := 0; i <= 20; i++ {
err := handler.ReceivedPacket(p, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
p++
err = handler.ReceivedPacket(p, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
p++
// dequeue the ACK frame
Expect(handler.GetAckFrame()).ToNot(BeNil())
}
})
It("queues an ACK for every 10 retransmittable packet, if they are arriving fast", func() {
receiveAndAck10Packets()
p := protocol.PacketNumber(10000)
for i := 0; i < 9; i++ {
err := handler.ReceivedPacket(p, time.Now(), true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
p++
}
Expect(handler.GetAlarmTimeout()).NotTo(BeZero())
err := handler.ReceivedPacket(p, time.Now(), true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(handler.GetAlarmTimeout()).To(BeZero())
})
It("only sets the timer when receiving a retransmittable packets", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, time.Now(), false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.GetAlarmTimeout()).To(BeZero())
rcvTime := time.Now().Add(10 * time.Millisecond)
err = handler.ReceivedPacket(12, rcvTime, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.GetAlarmTimeout()).To(Equal(rcvTime.Add(ackSendDelay)))
})
It("queues an ACK if it was reported missing before", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(13, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() // ACK: 1-11 and 13, missing: 12
Expect(ack).ToNot(BeNil())
Expect(ack.HasMissingRanges()).To(BeTrue())
Expect(handler.ackQueued).To(BeFalse())
err = handler.ReceivedPacket(12, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
})
It("doesn't queue an ACK if it was reported missing before, but is below the threshold", func() {
receiveAndAck10Packets()
// 11 is missing
err := handler.ReceivedPacket(12, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(13, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() // ACK: 1-10, 12-13
Expect(ack).ToNot(BeNil())
// now receive 11
handler.IgnoreBelow(12)
err = handler.ReceivedPacket(11, time.Time{}, false)
Expect(err).ToNot(HaveOccurred())
ack = handler.GetAckFrame()
Expect(ack).To(BeNil())
})
It("doesn't queue an ACK if the packet closes a gap that was not yet reported", func() {
receiveAndAckPacketsUntilAckDecimation()
p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1)
err := handler.ReceivedPacket(p+1, time.Now(), true) // p is missing now
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.GetAlarmTimeout()).ToNot(BeZero())
err = handler.ReceivedPacket(p, time.Now(), true) // p is not missing any more
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
})
It("sets an ACK alarm after 1/4 RTT if it creates a new missing range", func() {
now := time.Now().Add(-time.Hour)
rtt := 80 * time.Millisecond
rttStats.UpdateRTT(rtt, 0, now)
receiveAndAckPacketsUntilAckDecimation()
p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1)
for i := p; i < p+6; i++ {
err := handler.ReceivedPacket(i, now, true)
Expect(err).ToNot(HaveOccurred())
}
err := handler.ReceivedPacket(p+10, now, true) // we now know that packets p+7, p+8 and p+9
Expect(err).ToNot(HaveOccurred())
Expect(rttStats.MinRTT()).To(Equal(rtt))
Expect(handler.ackAlarm.Sub(now)).To(Equal(rtt / 8))
ack := handler.GetAckFrame()
Expect(ack.HasMissingRanges()).To(BeTrue())
Expect(ack).ToNot(BeNil())
})
})
Context("ACK generation", func() {
BeforeEach(func() {
handler.ackQueued = true
})
It("generates a simple ACK frame", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
It("generates an ACK for packet number 0", func() {
err := handler.ReceivedPacket(0, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
It("sets the delay time", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(2, time.Now().Add(-1337*time.Millisecond), true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond))
})
It("saves the last sent ACK", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
err = handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = true
ack = handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
})
It("generates an ACK frame with missing packets", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(4, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1)))
Expect(ack.AckRanges).To(Equal([]wire.AckRange{
{Smallest: 4, Largest: 4},
{Smallest: 1, Largest: 1},
}))
})
It("generates an ACK for packet number 0 and other packets", func() {
err := handler.ReceivedPacket(0, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(3, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0)))
Expect(ack.AckRanges).To(Equal([]wire.AckRange{
{Smallest: 3, Largest: 3},
{Smallest: 0, Largest: 1},
}))
})
It("accepts packets below the lower limit", func() {
handler.IgnoreBelow(6)
err := handler.ReceivedPacket(2, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
})
It("doesn't add delayed packets to the packetHistory", func() {
handler.IgnoreBelow(7)
err := handler.ReceivedPacket(4, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(10, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(10)))
})
It("deletes packets from the packetHistory when a lower limit is set", func() {
for i := 1; i <= 12; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
}
handler.IgnoreBelow(7)
// check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(12)))
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(7)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
// TODO: remove this test when dropping support for STOP_WAITINGs
It("handles a lower limit of 0", func() {
handler.IgnoreBelow(0)
err := handler.ReceivedPacket(1337, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(1337)))
})
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.packetsReceivedSinceLastAck).To(BeZero())
Expect(handler.GetAlarmTimeout()).To(BeZero())
Expect(handler.retransmittablePacketsReceivedSinceLastAck).To(BeZero())
Expect(handler.ackQueued).To(BeFalse())
})
It("doesn't generate an ACK when none is queued and the timer is not set", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Time{}
Expect(handler.GetAckFrame()).To(BeNil())
})
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(time.Minute)
Expect(handler.GetAckFrame()).To(BeNil())
})
It("generates an ACK when the timer has expired", func() {
err := handler.ReceivedPacket(1, time.Time{}, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
})
})
})
})

View File

@ -0,0 +1,121 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
// The receivedPacketHistory stores if a packet number has already been received.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges *utils.PacketIntervalList
lowestInReceivedPacketNumbers protocol.PacketNumber
}
var errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
// newReceivedPacketHistory creates a new received packet history
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{
ranges: utils.NewPacketIntervalList(),
}
}
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
if h.ranges.Len() >= protocol.MaxTrackedReceivedAckRanges {
return errTooManyOutstandingReceivedAckRanges
}
if h.ranges.Len() == 0 {
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
return nil
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
// p already included in an existing range. Nothing to do here
if p >= el.Value.Start && p <= el.Value.End {
return nil
}
var rangeExtended bool
if el.Value.End == p-1 { // extend a range at the end
rangeExtended = true
el.Value.End = p
} else if el.Value.Start == p+1 { // extend a range at the beginning
rangeExtended = true
el.Value.Start = p
}
// if a range was extended (either at the beginning or at the end, maybe it is possible to merge two ranges into one)
if rangeExtended {
prev := el.Prev()
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
prev.Value.End = el.Value.End
h.ranges.Remove(el)
return nil
}
return nil // if the two ranges were not merge, we're done here
}
// create a new range at the end
if p > el.Value.End {
h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el)
return nil
}
}
// create a new range at the beginning
h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front())
return nil
}
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p <= h.lowestInReceivedPacketNumbers {
return
}
h.lowestInReceivedPacketNumbers = p
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p
} else if el.Value.End < p { // delete a whole range
h.ranges.Remove(el)
} else { // no ranges affected. Nothing to do
return
}
}
}
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
if h.ranges.Len() == 0 {
return nil
}
ackRanges := make([]wire.AckRange, h.ranges.Len())
i := 0
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
i++
}
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.Smallest = r.Start
ackRange.Largest = r.End
}
return ackRange
}

View File

@ -0,0 +1,248 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("receivedPacketHistory", func() {
var (
hist *receivedPacketHistory
)
BeforeEach(func() {
hist = newReceivedPacketHistory()
})
Context("ranges", func() {
It("adds the first packet", func() {
hist.ReceivedPacket(4)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
})
It("doesn't care about duplicate packets", func() {
hist.ReceivedPacket(4)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
})
It("adds a few consecutive packets", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(6)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
})
It("doesn't care about a duplicate packet contained in an existing range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(6)
hist.ReceivedPacket(5)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
})
It("extends a range at the front", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(3)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 3, End: 4}))
})
It("creates a new range when a packet is lost", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(6)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6}))
})
It("creates a new range in between two ranges", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(10)
Expect(hist.ranges.Len()).To(Equal(2))
hist.ReceivedPacket(7)
Expect(hist.ranges.Len()).To(Equal(3))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
It("creates a new range before an existing range for a belated packet", func() {
hist.ReceivedPacket(6)
hist.ReceivedPacket(4)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6}))
})
It("extends a previous range at the end", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(7)
hist.ReceivedPacket(5)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 5}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7}))
})
It("extends a range at the front", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(7)
hist.ReceivedPacket(6)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 7}))
})
It("closes a range", func() {
hist.ReceivedPacket(6)
hist.ReceivedPacket(4)
Expect(hist.ranges.Len()).To(Equal(2))
hist.ReceivedPacket(5)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
})
It("closes a range in the middle", func() {
hist.ReceivedPacket(1)
hist.ReceivedPacket(10)
hist.ReceivedPacket(4)
hist.ReceivedPacket(6)
Expect(hist.ranges.Len()).To(Equal(4))
hist.ReceivedPacket(5)
Expect(hist.ranges.Len()).To(Equal(3))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 1, End: 1}))
Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
})
Context("deleting", func() {
It("does nothing when the history is empty", func() {
hist.DeleteBelow(5)
Expect(hist.ranges.Len()).To(BeZero())
})
It("deletes a range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(10)
hist.DeleteBelow(6)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
It("deletes multiple ranges", func() {
hist.ReceivedPacket(1)
hist.ReceivedPacket(5)
hist.ReceivedPacket(10)
hist.DeleteBelow(8)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
It("adjusts a range, if packets are delete from an existing range", func() {
hist.ReceivedPacket(3)
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(6)
hist.ReceivedPacket(7)
hist.DeleteBelow(5)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 7}))
})
It("adjusts a range, if only one packet remains in the range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(10)
hist.DeleteBelow(5)
Expect(hist.ranges.Len()).To(Equal(2))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 5}))
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
})
It("keeps a one-packet range, if deleting up to the packet directly below", func() {
hist.ReceivedPacket(4)
hist.DeleteBelow(4)
Expect(hist.ranges.Len()).To(Equal(1))
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
})
Context("DoS protection", func() {
It("doesn't create more than MaxTrackedReceivedAckRanges ranges", func() {
for i := protocol.PacketNumber(1); i <= protocol.MaxTrackedReceivedAckRanges; i++ {
err := hist.ReceivedPacket(2 * i)
Expect(err).ToNot(HaveOccurred())
}
err := hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 2)
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
})
It("doesn't consider already deleted ranges for MaxTrackedReceivedAckRanges", func() {
for i := protocol.PacketNumber(1); i <= protocol.MaxTrackedReceivedAckRanges; i++ {
err := hist.ReceivedPacket(2 * i)
Expect(err).ToNot(HaveOccurred())
}
err := hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 2)
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
hist.DeleteBelow(protocol.MaxTrackedReceivedAckRanges) // deletes about half of the ranges
err = hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 4)
Expect(err).ToNot(HaveOccurred())
})
})
})
Context("ACK range export", func() {
It("returns nil if there are no ranges", func() {
Expect(hist.GetAckRanges()).To(BeNil())
})
It("gets a single ACK range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
ackRanges := hist.GetAckRanges()
Expect(ackRanges).To(HaveLen(1))
Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
})
It("gets multiple ACK ranges", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
hist.ReceivedPacket(6)
hist.ReceivedPacket(1)
hist.ReceivedPacket(11)
hist.ReceivedPacket(10)
hist.ReceivedPacket(2)
ackRanges := hist.GetAckRanges()
Expect(ackRanges).To(HaveLen(3))
Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 10, Largest: 11}))
Expect(ackRanges[1]).To(Equal(wire.AckRange{Smallest: 4, Largest: 6}))
Expect(ackRanges[2]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2}))
})
})
Context("Getting the highest ACK range", func() {
It("returns the zero value if there are no ranges", func() {
Expect(hist.GetHighestAckRange()).To(BeZero())
})
It("gets a single ACK range", func() {
hist.ReceivedPacket(4)
hist.ReceivedPacket(5)
Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
})
It("gets the highest of multiple ACK ranges", func() {
hist.ReceivedPacket(3)
hist.ReceivedPacket(6)
hist.ReceivedPacket(7)
Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 6, Largest: 7}))
})
})
})

View File

@ -0,0 +1,36 @@
package ackhandler
import "github.com/lucas-clemente/quic-go/internal/wire"
// Returns a new slice with all non-retransmittable frames deleted.
func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
res := make([]wire.Frame, 0, len(fs))
for _, f := range fs {
if IsFrameRetransmittable(f) {
res = append(res, f)
}
}
return res
}
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f wire.Frame) bool {
switch f.(type) {
case *wire.StopWaitingFrame:
return false
case *wire.AckFrame:
return false
default:
return true
}
}
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
func HasRetransmittableFrames(fs []wire.Frame) bool {
for _, f := range fs {
if IsFrameRetransmittable(f) {
return true
}
}
return false
}

View File

@ -0,0 +1,45 @@
package ackhandler
import (
"reflect"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("retransmittable frames", func() {
for fl, el := range map[wire.Frame]bool{
&wire.AckFrame{}: false,
&wire.StopWaitingFrame{}: false,
&wire.BlockedFrame{}: true,
&wire.ConnectionCloseFrame{}: true,
&wire.GoawayFrame{}: true,
&wire.PingFrame{}: true,
&wire.RstStreamFrame{}: true,
&wire.StreamFrame{}: true,
&wire.MaxDataFrame{}: true,
&wire.MaxStreamDataFrame{}: true,
} {
f := fl
e := el
fName := reflect.ValueOf(f).Elem().Type().Name()
It("works for "+fName, func() {
Expect(IsFrameRetransmittable(f)).To(Equal(e))
})
It("stripping non-retransmittable frames works for "+fName, func() {
s := []wire.Frame{f}
if e {
Expect(stripNonRetransmittableFrames(s)).To(Equal([]wire.Frame{f}))
} else {
Expect(stripNonRetransmittableFrames(s)).To(BeEmpty())
}
})
It("HasRetransmittableFrames works for "+fName, func() {
Expect(HasRetransmittableFrames([]wire.Frame{f})).To(Equal(e))
})
}
})

View File

@ -0,0 +1,40 @@
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendRetransmission means that retransmissions should be sent
SendRetransmission
// SendRTO means that an RTO probe packet should be sent
SendRTO
// SendTLP means that a TLP probe packet should be sent
SendTLP
// SendAny means that any packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendRetransmission:
return "retransmission"
case SendRTO:
return "rto"
case SendTLP:
return "tlp"
case SendAny:
return "any"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}

View File

@ -0,0 +1,18 @@
package ackhandler
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Send Mode", func() {
It("has a string representation", func() {
Expect(SendNone.String()).To(Equal("none"))
Expect(SendAny.String()).To(Equal("any"))
Expect(SendAck.String()).To(Equal("ack"))
Expect(SendRTO.String()).To(Equal("rto"))
Expect(SendTLP.String()).To(Equal("tlp"))
Expect(SendRetransmission.String()).To(Equal("retransmission"))
Expect(SendMode(123).String()).To(Equal("invalid send mode: 123"))
})
})

View File

@ -0,0 +1,658 @@
package ackhandler
import (
"errors"
"fmt"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future a tail loss probe alarm may be set for.
minTPLTimeout = 10 * time.Millisecond
// Maximum number of tail loss probes before an RTO fires.
maxTLPs = 2
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
maxRTOTimeout = 60 * time.Second
)
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
lastSentRetransmittablePacketTime time.Time
lastSentHandshakePacketTime time.Time
nextPacketSendTime time.Time
skippedPackets []protocol.PacketNumber
largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
lowestPacketNotConfirmedAcked protocol.PacketNumber
largestSentBeforeRTO protocol.PacketNumber
packetHistory *sentPacketHistory
stopWaitingManager stopWaitingManager
retransmissionQueue []*Packet
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithm
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times a TLP has been sent without receiving an ack.
tlpCount uint32
allowTLP bool
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The number of RTO probe packets that should be sent.
numRTOs int
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
lossTime time.Time
// The alarm timeout
alarm time.Time
logger utils.Logger
version protocol.VersionNumber
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
false, /* don't use reno since chromium doesn't (why?) */
protocol.InitialCongestionWindow,
protocol.DefaultMaxCongestionWindow,
)
return &sentPacketHandler{
packetHistory: newSentPacketHistory(),
stopWaitingManager: stopWaitingManager{},
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
if p := h.packetHistory.FirstOutstanding(); p != nil {
return p.PacketNumber
}
return h.largestAcked + 1
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.packetHistory.Remove(p.PacketNumber)
}
h.retransmissionQueue = queue
h.handshakeComplete = true
}
func (h *sentPacketHandler) SentPacket(packet *Packet) {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
h.packetHistory.SentPacket(packet)
h.updateLossDetectionAlarm()
}
}
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
var p []*Packet
for _, packet := range packets {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
p = append(p, packet)
}
}
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
h.updateLossDetectionAlarm()
}
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.skippedPackets = append(h.skippedPackets, p)
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
h.skippedPackets = h.skippedPackets[1:]
}
}
h.lastSentPacketNumber = packet.PacketNumber
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
packet.largestAcked = ackFrame.LargestAcked()
}
}
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
h.lastSentHandshakePacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
packet.canBeRetransmitted = true
if h.numRTOs > 0 {
h.numRTOs--
}
h.allowTLP = false
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
return isRetransmittable
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
largestAcked := ackFrame.LargestAcked()
if largestAcked > h.lastSentPacketNumber {
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
}
// duplicate or out of order ACK
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
return nil
}
h.largestReceivedPacketWithAck = withPacketNumber
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
if h.skippedPacketsAcked(ackFrame) {
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
}
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
h.congestion.MaybeExitSlowStart()
}
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
if err != nil {
return err
}
priorInFlight := h.bytesInFlight
for _, p := range ackedPackets {
if encLevel < p.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
}
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
}
if err := h.onPacketAcked(p, rcvTime); err != nil {
return err
}
if p.includedInBytesInFlight {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
}
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
return err
}
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
return nil
}
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestPacketNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
var ackedPackets []*Packet
ackRangeIndex := 0
lowestAcked := ackFrame.LowestAcked()
largestAcked := ackFrame.LargestAcked()
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
}
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
ackedPackets = append(ackedPackets, p)
}
} else {
ackedPackets = append(ackedPackets, p)
}
return true, nil
})
if h.logger.Debug() && len(ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(ackedPackets))
for i, p := range ackedPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
}
return ackedPackets, err
}
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
return true
}
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if !h.packetHistory.HasOutstandingPackets() {
h.alarm = time.Time{}
return
}
if h.packetHistory.HasOutstandingHandshakePackets() {
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO or TLP alarm
alarmDuration := h.computeRTOTimeout()
if h.tlpCount < maxTLPs {
tlpAlarm := h.computeTLPTimeout()
// if the RTO duration is shorter than the TLP duration, use the RTO duration
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
}
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
h.lossTime = time.Time{}
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
var lostPackets []*Packet
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
if packet.PacketNumber > h.largestAcked {
return false, nil
}
timeSinceSent := now.Sub(packet.SendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, packet)
} else if h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
}
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
return true, nil
})
if h.logger.Debug() && len(lostPackets) > 0 {
pns := make([]protocol.PacketNumber, len(lostPackets))
for i, p := range lostPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
}
for _, p := range lostPackets {
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
if p.canBeRetransmitted {
// queue the packet for retransmission, and report the loss to the congestion controller
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
h.packetHistory.Remove(p.PacketNumber)
}
return nil
}
func (h *sentPacketHandler) OnAlarm() error {
// When all outstanding are acknowledged, the alarm is canceled in
// updateLossDetectionAlarm. This doesn't reset the timer in the session though.
// When OnAlarm is called, we therefore need to make sure that there are
// actually packets outstanding.
if h.packetHistory.HasOutstandingPackets() {
if err := h.onVerifiedAlarm(); err != nil {
return err
}
}
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) onVerifiedAlarm() error {
var err error
if h.packetHistory.HasOutstandingHandshakePackets() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
}
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
} else if !h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
}
// Early retransmit or time loss detection
err = h.detectLostPackets(time.Now(), h.bytesInFlight)
} else if h.tlpCount < maxTLPs { // TLP
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in TLP mode. TLP count: %d", h.tlpCount)
}
h.allowTLP = true
h.tlpCount++
} else { // RTO
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in RTO mode. RTO count: %d", h.rtoCount)
}
if h.rtoCount == 0 {
h.largestSentBeforeRTO = h.lastSentPacketNumber
}
h.rtoCount++
h.numRTOs += 2
}
return err
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
// This happens if a packet and its retransmissions is acked in the same ACK.
// As soon as we process the first one, this will remove all the retransmissions,
// so we won't find the retransmitted packet number later.
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
return nil
}
// only report the acking of this packet to the congestion controller if:
// * it is a retransmittable packet
// * this packet wasn't retransmitted yet
if p.isRetransmission {
// that the parent doesn't exist is expected to happen every time the original packet was already acked
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
if len(parent.retransmittedAs) == 1 {
parent.retransmittedAs = nil
} else {
// remove this packet from the slice of retransmission
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
for _, pn := range parent.retransmittedAs {
if pn != p.PacketNumber {
retransmittedAs = append(retransmittedAs, pn)
}
}
parent.retransmittedAs = retransmittedAs
}
}
}
// this also applies to packets that have been retransmitted as probe packets
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if h.rtoCount > 0 {
h.verifyRTO(p.PacketNumber)
}
if err := h.stopRetransmissionsFor(p); err != nil {
return err
}
h.rtoCount = 0
h.tlpCount = 0
h.handshakeCount = 0
return h.packetHistory.Remove(p.PacketNumber)
}
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
for _, r := range p.retransmittedAs {
packet := h.packetHistory.GetPacket(r)
if packet == nil {
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
}
h.stopRetransmissionsFor(packet)
}
return nil
}
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
if pn <= h.largestSentBeforeRTO {
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
// Replace SRTT with latest_rtt and increase the variance to prevent
// a spurious RTO from happening again.
h.rttStats.ExpireSmoothedMetrics()
return
}
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 {
return nil
}
packet := h.retransmissionQueue[0]
// Shift the slice and don't retain anything that isn't needed.
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet
}
func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
if len(h.retransmissionQueue) == 0 {
p := h.packetHistory.FirstOutstanding()
if p == nil {
return nil, errors.New("cannot dequeue a probe packet. No outstanding packets")
}
if err := h.queuePacketForRetransmission(p); err != nil {
return nil, err
}
}
return h.DequeuePacketForRetransmission(), nil
}
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
}
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
}
return SendNone
}
if h.allowTLP {
return SendTLP
}
if h.numRTOs > 0 {
return SendRTO
}
// Only send ACKs if we're congestion limited.
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
if h.logger.Debug() {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
}
return SendAck
}
// Send retransmissions first, if there are any.
if len(h.retransmissionQueue) > 0 {
return SendRetransmission
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
}
return SendAck
}
return SendAny
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.nextPacketSendTime
}
func (h *sentPacketHandler) ShouldSendNumPackets() int {
if h.numRTOs > 0 {
// RTO probes should not be paced, but must be sent immediately.
return h.numRTOs
}
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
if delay == 0 || delay > protocol.MinPacingDelay {
return 1
}
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
return nil
}
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
if !p.canBeRetransmitted {
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
}
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
h.retransmissionQueue = append(h.retransmissionQueue, p)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
return nil
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
}
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
// TODO(#1236): include the max_ack_delay
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
var rto time.Duration
rtt := h.rttStats.SmoothedRTT()
if rtt == 0 {
rto = defaultRTOTimeout
} else {
rto = rtt + 4*h.rttStats.MeanDeviation()
}
rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff
rto = rto << h.rtoCount
return utils.MinDuration(rto, maxRTOTimeout)
}
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
for _, p := range h.skippedPackets {
if ackFrame.AcksPacket(p) {
return true
}
}
return false
}
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
lowestUnacked := h.lowestUnacked()
deleteIndex := 0
for i, p := range h.skippedPackets {
if p < lowestUnacked {
deleteIndex = i + 1
}
}
h.skippedPackets = h.skippedPackets[deleteIndex:]
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,168 @@
package ackhandler
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
numOutstandingPackets int
numOutstandingHandshakePackets int
firstOutstanding *PacketElement
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
}
}
func (h *sentPacketHistory) SentPacket(p *Packet) {
h.sentPacketImpl(p)
}
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
if p.canBeRetransmitted {
h.numOutstandingPackets++
if p.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets++
}
}
return el
}
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
retransmission, ok := h.packetMap[retransmissionOf]
// The retransmitted packet is not present anymore.
// This can happen if it was acked in between dequeueing of the retransmission and sending.
// Just treat the retransmissions as normal packets.
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
if !ok {
for _, packet := range packets {
h.sentPacketImpl(packet)
}
return
}
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
for i, packet := range packets {
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
el := h.sentPacketImpl(packet)
el.Value.isRetransmission = true
el.Value.retransmissionOf = retransmissionOf
}
}
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
if el, ok := h.packetMap[p]; ok {
return &el.Value
}
return nil
}
// Iterate iterates through all packets.
// The callback must not modify the history.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
cont := true
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
var err error
cont, err = cb(&el.Value)
if err != nil {
return err
}
}
return nil
}
// FirstOutStanding returns the first outstanding packet.
// It must not be modified (e.g. retransmitted).
// Use DequeueFirstPacketForRetransmission() to retransmit it.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
if h.firstOutstanding == nil {
return nil
}
return &h.firstOutstanding.Value
}
// QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once.
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
el, ok := h.packetMap[pn]
if !ok {
return fmt.Errorf("sent packet history: packet %d not found", pn)
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
}
el.Value.canBeRetransmitted = false
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
return nil
}
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next()
for el != nil && !el.Value.canBeRetransmitted {
el = el.Next()
}
h.firstOutstanding = el
}
func (h *sentPacketHistory) Len() int {
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstandingPackets > 0
}
func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool {
return h.numOutstandingHandshakePackets > 0
}

View File

@ -0,0 +1,297 @@
package ackhandler
import (
"errors"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("SentPacketHistory", func() {
var hist *sentPacketHistory
expectInHistory := func(packetNumbers []protocol.PacketNumber) {
ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers)))
ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers)))
i := 0
hist.Iterate(func(p *Packet) (bool, error) {
pn := packetNumbers[i]
ExpectWithOffset(1, p.PacketNumber).To(Equal(pn))
ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn))
i++
return true, nil
})
}
BeforeEach(func() {
hist = newSentPacketHistory()
})
It("saves sent packets", func() {
hist.SentPacket(&Packet{PacketNumber: 1})
hist.SentPacket(&Packet{PacketNumber: 3})
hist.SentPacket(&Packet{PacketNumber: 4})
expectInHistory([]protocol.PacketNumber{1, 3, 4})
})
It("gets the length", func() {
hist.SentPacket(&Packet{PacketNumber: 1})
hist.SentPacket(&Packet{PacketNumber: 10})
Expect(hist.Len()).To(Equal(2))
})
Context("getting the first outstanding packet", func() {
It("gets nil, if there are no packets", func() {
Expect(hist.FirstOutstanding()).To(BeNil())
})
It("gets the first outstanding packet", func() {
hist.SentPacket(&Packet{PacketNumber: 2})
hist.SentPacket(&Packet{PacketNumber: 3})
front := hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2)))
})
It("gets the second packet if the first one is retransmitted", func() {
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
front := hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the first packet for retransmission.
// The first outstanding packet should now be 3.
err := hist.MarkCannotBeRetransmitted(1)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(3)))
})
It("gets the third packet if the first two are retransmitted", func() {
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
front := hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the second packet for retransmission.
// The first outstanding packet should still be 3.
err := hist.MarkCannotBeRetransmitted(3)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
// Queue the first packet for retransmission.
// The first outstanding packet should still be 4.
err = hist.MarkCannotBeRetransmitted(1)
Expect(err).ToNot(HaveOccurred())
front = hist.FirstOutstanding()
Expect(front).ToNot(BeNil())
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(4)))
})
})
It("gets a packet by packet number", func() {
p := &Packet{PacketNumber: 2}
hist.SentPacket(p)
Expect(hist.GetPacket(2)).To(Equal(p))
})
It("returns nil if the packet doesn't exist", func() {
Expect(hist.GetPacket(1337)).To(BeNil())
})
It("removes packets", func() {
hist.SentPacket(&Packet{PacketNumber: 1})
hist.SentPacket(&Packet{PacketNumber: 4})
hist.SentPacket(&Packet{PacketNumber: 8})
err := hist.Remove(4)
Expect(err).ToNot(HaveOccurred())
expectInHistory([]protocol.PacketNumber{1, 8})
})
It("errors when trying to remove a non existing packet", func() {
hist.SentPacket(&Packet{PacketNumber: 1})
err := hist.Remove(2)
Expect(err).To(MatchError("packet 2 not found in sent packet history"))
})
Context("iterating", func() {
BeforeEach(func() {
hist.SentPacket(&Packet{PacketNumber: 10})
hist.SentPacket(&Packet{PacketNumber: 14})
hist.SentPacket(&Packet{PacketNumber: 18})
})
It("iterates over all packets", func() {
var iterations []protocol.PacketNumber
err := hist.Iterate(func(p *Packet) (bool, error) {
iterations = append(iterations, p.PacketNumber)
return true, nil
})
Expect(err).ToNot(HaveOccurred())
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18}))
})
It("stops iterating", func() {
var iterations []protocol.PacketNumber
err := hist.Iterate(func(p *Packet) (bool, error) {
iterations = append(iterations, p.PacketNumber)
return p.PacketNumber != 14, nil
})
Expect(err).ToNot(HaveOccurred())
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
})
It("returns the error", func() {
testErr := errors.New("test error")
var iterations []protocol.PacketNumber
err := hist.Iterate(func(p *Packet) (bool, error) {
iterations = append(iterations, p.PacketNumber)
if p.PacketNumber == 14 {
return false, testErr
}
return true, nil
})
Expect(err).To(MatchError(testErr))
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
})
})
Context("retransmissions", func() {
BeforeEach(func() {
for i := protocol.PacketNumber(1); i <= 5; i++ {
hist.SentPacket(&Packet{PacketNumber: i})
}
})
It("errors if the packet doesn't exist", func() {
err := hist.MarkCannotBeRetransmitted(100)
Expect(err).To(MatchError("sent packet history: packet 100 not found"))
})
It("adds a sent packets as a retransmission", func() {
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 2)
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13})
Expect(hist.GetPacket(13).isRetransmission).To(BeTrue())
Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13}))
})
It("adds multiple packets sent as a retransmission", func() {
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}, {PacketNumber: 15}}, 2)
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13, 15})
Expect(hist.GetPacket(13).isRetransmission).To(BeTrue())
Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
Expect(hist.GetPacket(15).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13, 15}))
})
It("adds a packet as a normal packet if the retransmitted packet doesn't exist", func() {
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 7)
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13})
Expect(hist.GetPacket(13).isRetransmission).To(BeFalse())
Expect(hist.GetPacket(13).retransmissionOf).To(BeZero())
})
})
Context("outstanding packets", func() {
It("says if it has outstanding handshake packets", func() {
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionUnencrypted,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
})
It("says if it has outstanding packets", func() {
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
Expect(hist.HasOutstandingPackets()).To(BeTrue())
})
It("doesn't consider non-retransmittable packets as outstanding", func() {
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionUnencrypted,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("accounts for deleted handshake packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 5,
EncryptionLevel: protocol.EncryptionSecure,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
err := hist.Remove(5)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
})
It("accounts for deleted packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err := hist.Remove(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("doesn't count handshake packets marked as non-retransmittable", func() {
hist.SentPacket(&Packet{
PacketNumber: 5,
EncryptionLevel: protocol.EncryptionUnencrypted,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
err := hist.MarkCannotBeRetransmitted(5)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
})
It("doesn't count packets marked as non-retransmittable", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err := hist.MarkCannotBeRetransmitted(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
It("counts the number of packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
hist.SentPacket(&Packet{
PacketNumber: 11,
EncryptionLevel: protocol.EncryptionForwardSecure,
canBeRetransmitted: true,
})
err := hist.Remove(11)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeTrue())
err = hist.Remove(10)
Expect(err).ToNot(HaveOccurred())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
})
})
})

View File

@ -0,0 +1,43 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
type stopWaitingManager struct {
largestLeastUnackedSent protocol.PacketNumber
nextLeastUnacked protocol.PacketNumber
lastStopWaitingFrame *wire.StopWaitingFrame
}
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
if force {
return s.lastStopWaitingFrame
}
return nil
}
s.largestLeastUnackedSent = s.nextLeastUnacked
swf := &wire.StopWaitingFrame{
LeastUnacked: s.nextLeastUnacked,
}
s.lastStopWaitingFrame = swf
return swf
}
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
largestAcked := ack.LargestAcked()
if largestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = largestAcked + 1
}
}
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
if p >= s.nextLeastUnacked {
s.nextLeastUnacked = p + 1
}
}

View File

@ -0,0 +1,55 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("StopWaitingManager", func() {
var manager *stopWaitingManager
BeforeEach(func() {
manager = &stopWaitingManager{}
})
It("returns nil in the beginning", func() {
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
Expect(manager.GetStopWaitingFrame(true)).To(BeNil())
})
It("returns a StopWaitingFrame, when a new ACK arrives", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
It("does not decrease the LeastUnacked", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 9}}})
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
It("does not send the same StopWaitingFrame twice", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
})
It("gets the same StopWaitingFrame twice, if forced", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
})
It("increases the LeastUnacked when a retransmission is queued", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.QueuedRetransmissionForPacketNumber(20)
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 21}))
})
It("does not decrease the LeastUnacked when a retransmission is queued", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.QueuedRetransmissionForPacketNumber(9)
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
})

View File

@ -0,0 +1,22 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// Bandwidth of a connection
type Bandwidth uint64
const (
// BitsPerSecond is 1 bit per second
BitsPerSecond Bandwidth = 1
// BytesPerSecond is 1 byte per second
BytesPerSecond = 8 * BitsPerSecond
)
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
}

View File

@ -0,0 +1,14 @@
package congestion
import (
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Bandwidth", func() {
It("converts from time delta", func() {
Expect(BandwidthFromDelta(1, time.Millisecond)).To(Equal(1000 * BytesPerSecond))
})
})

View File

@ -0,0 +1,18 @@
package congestion
import "time"
// A Clock returns the current time
type Clock interface {
Now() time.Time
}
// DefaultClock implements the Clock interface using the Go stdlib clock.
type DefaultClock struct{}
var _ Clock = DefaultClock{}
// Now gets the current time
func (DefaultClock) Now() time.Time {
return time.Now()
}

View File

@ -0,0 +1,13 @@
package congestion
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestCongestion(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Congestion Suite")
}

View File

@ -0,0 +1,210 @@
package congestion
import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// This cubic implementation is based on the one found in Chromiums's QUIC
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
// Constants based on TCP defaults.
// The following constants are in 2^10 fractions of a second instead of ms to
// allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling round trip time.
const cubeScale = 40
const cubeCongestionWindowScale = 410
const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
const defaultNumConnections = 2
// Default Cubic backoff factor
const beta float32 = 0.7
// Additional backoff factor when loss occurs in the concave part of the Cubic
// curve. This additional backoff factor is expected to give up bandwidth to
// new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85
// Cubic implements the cubic algorithm from TCP
type Cubic struct {
clock Clock
// Number of connections to simulate.
numConnections int
// Time when this cycle started, after last loss event.
epoch time.Time
// Max congestion window used just before last loss event.
// Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value.
lastMaxCongestionWindow protocol.ByteCount
// Number of acked bytes since the cycle started (epoch).
ackedBytesCount protocol.ByteCount
// TCP Reno equivalent congestion window in packets.
estimatedTCPcongestionWindow protocol.ByteCount
// Origin point of cubic function.
originPointCongestionWindow protocol.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32
// Last congestion window in packets computed by cubic function.
lastTargetCongestionWindow protocol.ByteCount
}
// NewCubic returns a new Cubic instance
func NewCubic(clock Clock) *Cubic {
c := &Cubic{
clock: clock,
numConnections: defaultNumConnections,
}
c.Reset()
return c
}
// Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() {
c.epoch = time.Time{}
c.lastMaxCongestionWindow = 0
c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0
c.lastTargetCongestionWindow = 0
}
func (c *Cubic) alpha() float32 {
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
// We derive the equivalent alpha for an N-connection emulation as:
b := c.beta()
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
}
func (c *Cubic) beta() float32 {
// kNConnectionBeta is the backoff factor after loss for our N-connection
// emulation, which emulates the effective backoff of an ensemble of N
// TCP-Reno connections on a single loss event. The effective multiplier is
// computed as:
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
}
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() {
// When sender is not using the available congestion window, the window does
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
// using the entire window during the time since the beginning of the current
// "epoch" (the end of the last loss recovery period). Since
// application-limited periods break this assumption, we reset the epoch when
// in such a period. This reset effectively freezes congestion window growth
// through application-limited periods and allows Cubic growth to continue
// when the entire window is being used.
c.epoch = time.Time{}
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in packets. The new
// congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else {
c.lastMaxCongestionWindow = currentCongestionWindow
}
c.epoch = time.Time{} // Reset time.
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
}
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in packets. The new congestion window
// follows a cubic function that depends on the time passed since last
// packet loss.
func (c *Cubic) CongestionWindowAfterAck(
ackedBytes protocol.ByteCount,
currentCongestionWindow protocol.ByteCount,
delayMin time.Duration,
eventTime time.Time,
) protocol.ByteCount {
c.ackedBytesCount += ackedBytes
if c.epoch.IsZero() {
// First ACK after a loss event.
c.epoch = eventTime // Start of epoch.
c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow {
c.timeToOriginPoint = 0
c.originPointCongestionWindow = currentCongestionWindow
} else {
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow
}
}
// Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a
// divide operator.
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
var targetCongestionWindow protocol.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// Limit the CWND increase to half the acked bytes.
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// Increase the window by approximately Alpha * 1 MSS of bytes every
// time we ack an estimated tcp window of bytes. For small
// congestion windows (less than 25), the formula below will
// increase slightly slower than linearly per estimated tcp window
// of bytes.
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
c.ackedBytesCount = 0
// We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow
// Compute target congestion_window based on cubic target and estimated TCP
// congestion_window, use highest (fastest).
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow
}
return targetCongestionWindow
}
// SetNumConnections sets the number of emulated connections
func (c *Cubic) SetNumConnections(n int) {
c.numConnections = n
}

View File

@ -0,0 +1,318 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const (
maxBurstBytes = 3 * protocol.DefaultTCPMSS
renoBeta float32 = 0.7 // Reno backoff factor.
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
)
type cubicSender struct {
hybridSlowStart HybridSlowStart
prr PrrSender
rttStats *RTTStats
stats connectionStats
cubic *Cubic
reno bool
// Track the largest packet that has been sent.
largestSentPacketNumber protocol.PacketNumber
// Track the largest packet that has been acked.
largestAckedPacketNumber protocol.PacketNumber
// Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback protocol.PacketNumber
// Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool
// When true, exit slow start with large cutback of congestion window.
slowStartLargeReduction bool
// Congestion window in packets.
congestionWindow protocol.ByteCount
// Minimum congestion window in packets.
minCongestionWindow protocol.ByteCount
// Maximum congestion window.
maxCongestionWindow protocol.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowstartThreshold protocol.ByteCount
// Number of connections to simulate.
numConnections int
// ACK counter for the Reno implementation.
numAckedPackets uint64
initialCongestionWindow protocol.ByteCount
initialMaxCongestionWindow protocol.ByteCount
minSlowStartExitWindow protocol.ByteCount
}
var _ SendAlgorithm = &cubicSender{}
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
// NewCubicSender makes a new cubic sender
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
return &cubicSender{
rttStats: rttStats,
initialCongestionWindow: initialCongestionWindow,
initialMaxCongestionWindow: initialMaxCongestionWindow,
congestionWindow: initialCongestionWindow,
minCongestionWindow: defaultMinimumCongestionWindow,
slowstartThreshold: initialMaxCongestionWindow,
maxCongestionWindow: initialMaxCongestionWindow,
numConnections: defaultNumConnections,
cubic: NewCubic(clock),
reno: reno,
}
}
// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
if c.InRecovery() {
// PRR is used when in recovery.
if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
return 0
}
}
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
delay = delay * 8 / 5
}
return delay
}
func (c *cubicSender) OnPacketSent(
sentTime time.Time,
bytesInFlight protocol.ByteCount,
packetNumber protocol.PacketNumber,
bytes protocol.ByteCount,
isRetransmittable bool,
) {
if !isRetransmittable {
return
}
if c.InRecovery() {
// PRR is used when in recovery.
c.prr.OnPacketSent(bytes)
}
c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber)
}
func (c *cubicSender) InRecovery() bool {
return c.largestAckedPacketNumber <= c.largestSentAtLastCutback && c.largestAckedPacketNumber != 0
}
func (c *cubicSender) InSlowStart() bool {
return c.GetCongestionWindow() < c.GetSlowStartThreshold()
}
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
return c.congestionWindow
}
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
return c.slowstartThreshold
}
func (c *cubicSender) ExitSlowstart() {
c.slowstartThreshold = c.congestionWindow
}
func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
return c.slowstartThreshold
}
func (c *cubicSender) MaybeExitSlowStart() {
if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
c.ExitSlowstart()
}
}
func (c *cubicSender) OnPacketAcked(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
// PRR is used when in recovery.
c.prr.OnPacketAcked(ackedBytes)
return
}
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
}
}
func (c *cubicSender) OnPacketLost(
packetNumber protocol.PacketNumber,
lostBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {
if c.lastCutbackExitedSlowstart {
c.stats.slowstartPacketsLost++
c.stats.slowstartBytesLost += lostBytes
if c.slowStartLargeReduction {
// Reduce congestion window by lost_bytes for every loss.
c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
c.slowstartThreshold = c.congestionWindow
}
}
return
}
c.lastCutbackExitedSlowstart = c.InSlowStart()
if c.InSlowStart() {
c.stats.slowstartPacketsLost++
}
c.prr.OnPacketLost(priorInFlight)
// TODO(chromium): Separate out all of slow start into a separate class.
if c.slowStartLargeReduction && c.InSlowStart() {
if c.congestionWindow >= 2*c.initialCongestionWindow {
c.minSlowStartExitWindow = c.congestionWindow / 2
}
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
} else if c.reno {
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
} else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
}
if c.congestionWindow < c.minCongestionWindow {
c.congestionWindow = c.minCongestionWindow
}
c.slowstartThreshold = c.congestionWindow
c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery.
c.numAckedPackets = 0
}
func (c *cubicSender) RenoBeta() float32 {
// kNConnectionBeta is the backoff factor after loss for our N-connection
// emulation, which emulates the effective backoff of an ensemble of N
// TCP-Reno connections on a single loss event. The effective multiplier is
// computed as:
return (float32(c.numConnections) - 1. + renoBeta) / float32(c.numConnections)
}
// Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using
// the current window.
if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited()
return
}
if c.congestionWindow >= c.maxCongestionWindow {
return
}
if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow += protocol.DefaultTCPMSS
return
}
// Congestion avoidance
if c.reno {
// Classic Reno congestion avoidance.
c.numAckedPackets++
// Divide by num_connections to smoothly increase the CWND at a faster
// rate than conventional Reno.
if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
c.congestionWindow += protocol.DefaultTCPMSS
c.numAckedPackets = 0
}
} else {
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
}
}
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
congestionWindow := c.GetCongestionWindow()
if bytesInFlight >= congestionWindow {
return true
}
availableBytes := congestionWindow - bytesInFlight
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
return slowStartLimited || availableBytes <= maxBurstBytes
}
// BandwidthEstimate returns the current bandwidth estimate
func (c *cubicSender) BandwidthEstimate() Bandwidth {
srtt := c.rttStats.SmoothedRTT()
if srtt == 0 {
// If we haven't measured an rtt, the bandwidth estimate is unknown.
return 0
}
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
}
// HybridSlowStart returns the hybrid slow start instance for testing
func (c *cubicSender) HybridSlowStart() *HybridSlowStart {
return &c.hybridSlowStart
}
// SetNumEmulatedConnections sets the number of emulated connections
func (c *cubicSender) SetNumEmulatedConnections(n int) {
c.numConnections = utils.Max(n, 1)
c.cubic.SetNumConnections(c.numConnections)
}
// OnRetransmissionTimeout is called on an retransmission timeout
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
c.largestSentAtLastCutback = 0
if !packetsRetransmitted {
return
}
c.hybridSlowStart.Restart()
c.cubic.Reset()
c.slowstartThreshold = c.congestionWindow / 2
c.congestionWindow = c.minCongestionWindow
}
// OnConnectionMigration is called when the connection is migrated (?)
func (c *cubicSender) OnConnectionMigration() {
c.hybridSlowStart.Restart()
c.prr = PrrSender{}
c.largestSentPacketNumber = 0
c.largestAckedPacketNumber = 0
c.largestSentAtLastCutback = 0
c.lastCutbackExitedSlowstart = false
c.cubic.Reset()
c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow
c.slowstartThreshold = c.initialMaxCongestionWindow
c.maxCongestionWindow = c.initialMaxCongestionWindow
}
// SetSlowStartLargeReduction allows enabling the SSLR experiment
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
c.slowStartLargeReduction = enabled
}

View File

@ -0,0 +1,640 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const initialCongestionWindowPackets = 10
const defaultWindowTCP = protocol.ByteCount(initialCongestionWindowPackets) * protocol.DefaultTCPMSS
type mockClock time.Time
func (c *mockClock) Now() time.Time {
return time.Time(*c)
}
func (c *mockClock) Advance(d time.Duration) {
*c = mockClock(time.Time(*c).Add(d))
}
const MaxCongestionWindow protocol.ByteCount = 200 * protocol.DefaultTCPMSS
var _ = Describe("Cubic Sender", func() {
var (
sender SendAlgorithmWithDebugInfo
clock mockClock
bytesInFlight protocol.ByteCount
packetNumber protocol.PacketNumber
ackedPacketNumber protocol.PacketNumber
rttStats *RTTStats
)
BeforeEach(func() {
bytesInFlight = 0
packetNumber = 1
ackedPacketNumber = 0
clock = mockClock{}
rttStats = NewRTTStats()
sender = NewCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*protocol.DefaultTCPMSS, MaxCongestionWindow)
})
canSend := func() bool {
return bytesInFlight < sender.GetCongestionWindow()
}
SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int {
packetsSent := 0
for canSend() {
sender.OnPacketSent(clock.Now(), bytesInFlight, packetNumber, packetLength, true)
packetNumber++
packetsSent++
bytesInFlight += packetLength
}
return packetsSent
}
// Normal is that TCP acks every other segment.
AckNPackets := func(n int) {
rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now())
sender.MaybeExitSlowStart()
for i := 0; i < n; i++ {
ackedPacketNumber++
sender.OnPacketAcked(ackedPacketNumber, protocol.DefaultTCPMSS, bytesInFlight, clock.Now())
}
bytesInFlight -= protocol.ByteCount(n) * protocol.DefaultTCPMSS
clock.Advance(time.Millisecond)
}
LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) {
for i := 0; i < n; i++ {
ackedPacketNumber++
sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight)
}
bytesInFlight -= protocol.ByteCount(n) * packetLength
}
// Does not increment acked_packet_number_.
LosePacket := func(number protocol.PacketNumber) {
sender.OnPacketLost(number, protocol.DefaultTCPMSS, bytesInFlight)
bytesInFlight -= protocol.DefaultTCPMSS
}
SendAvailableSendWindow := func() int { return SendAvailableSendWindowLen(protocol.DefaultTCPMSS) }
LoseNPackets := func(n int) { LoseNPacketsLen(n, protocol.DefaultTCPMSS) }
It("has the right values at startup", func() {
// At startup make sure we are at the default.
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
// Make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
Expect(canSend()).To(BeTrue())
// And that window is un-affected.
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
// Fill the send window with data, then verify that we can't send.
SendAvailableSendWindow()
Expect(canSend()).To(BeFalse())
})
It("paces", func() {
clock.Advance(time.Hour)
// Fill the send window with data, then verify that we can't send.
SendAvailableSendWindow()
AckNPackets(1)
delay := sender.TimeUntilSend(bytesInFlight)
Expect(delay).ToNot(BeZero())
Expect(delay).ToNot(Equal(utils.InfDuration))
})
It("application limited slow start", func() {
// Send exactly 10 packets and ensure the CWND ends at 14 packets.
const numberOfAcks = 5
// At startup make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
// Make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
SendAvailableSendWindow()
for i := 0; i < numberOfAcks; i++ {
AckNPackets(2)
}
bytesToSend := sender.GetCongestionWindow()
// It's expected 2 acks will arrive when the bytes_in_flight are greater than
// half the CWND.
Expect(bytesToSend).To(Equal(defaultWindowTCP + protocol.DefaultTCPMSS*2*2))
})
It("exponential slow start", func() {
const numberOfAcks = 20
// At startup make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
Expect(sender.BandwidthEstimate()).To(BeZero())
// Make sure we can send.
Expect(sender.TimeUntilSend(0)).To(BeZero())
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
cwnd := sender.GetCongestionWindow()
Expect(cwnd).To(Equal(defaultWindowTCP + protocol.DefaultTCPMSS*2*numberOfAcks))
Expect(sender.BandwidthEstimate()).To(Equal(BandwidthFromDelta(cwnd, rttStats.SmoothedRTT())))
})
It("slow start packet loss", func() {
sender.SetNumEmulatedConnections(1)
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose a packet to exit slow start.
LoseNPackets(1)
packetsInRecoveryWindow := expectedSendWindow / protocol.DefaultTCPMSS
// We should now have fallen out of slow start with a reduced window.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Recovery phase. We need to ack every packet in the recovery window before
// we exit recovery.
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
AckNPackets(int(packetsInRecoveryWindow))
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// We need to ack an entire window before we increase CWND by 1.
AckNPackets(int(numberOfPacketsInWindow) - 2)
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Next ack should increase cwnd by 1.
AckNPackets(1)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Now RTO and ensure slow start gets reset.
Expect(sender.HybridSlowStart().Started()).To(BeTrue())
sender.OnRetransmissionTimeout(true)
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
})
It("slow start packet loss with large reduction", func() {
sender.SetSlowStartLargeReduction(true)
sender.SetNumEmulatedConnections(1)
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose a packet to exit slow start. We should now have fallen out of
// slow start with a window reduced by 1.
LoseNPackets(1)
expectedSendWindow -= protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose 5 packets in recovery and verify that congestion window is reduced
// further.
LoseNPackets(5)
expectedSendWindow -= 5 * protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
packetsInRecoveryWindow := expectedSendWindow / protocol.DefaultTCPMSS
// Recovery phase. We need to ack every packet in the recovery window before
// we exit recovery.
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
AckNPackets(int(packetsInRecoveryWindow))
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// We need to ack the rest of the window before cwnd increases by 1.
AckNPackets(int(numberOfPacketsInWindow - 1))
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Next ack should increase cwnd by 1.
AckNPackets(1)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Now RTO and ensure slow start gets reset.
Expect(sender.HybridSlowStart().Started()).To(BeTrue())
sender.OnRetransmissionTimeout(true)
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
})
It("slow start half packet loss with large reduction", func() {
sender.SetSlowStartLargeReduction(true)
sender.SetNumEmulatedConnections(1)
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window in half sized packets.
SendAvailableSendWindowLen(protocol.DefaultTCPMSS / 2)
AckNPackets(2)
}
SendAvailableSendWindowLen(protocol.DefaultTCPMSS / 2)
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose a packet to exit slow start. We should now have fallen out of
// slow start with a window reduced by 1.
LoseNPackets(1)
expectedSendWindow -= protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose 10 packets in recovery and verify that congestion window is reduced
// by 5 packets.
LoseNPacketsLen(10, protocol.DefaultTCPMSS/2)
expectedSendWindow -= 5 * protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
})
// this test doesn't work any more after introducing the pacing needed for QUIC
PIt("no PRR when less than one packet in flight", func() {
SendAvailableSendWindow()
LoseNPackets(int(initialCongestionWindowPackets) - 1)
AckNPackets(1)
// PRR will allow 2 packets for every ack during recovery.
Expect(SendAvailableSendWindow()).To(Equal(2))
// Simulate abandoning all packets by supplying a bytes_in_flight of 0.
// PRR should now allow a packet to be sent, even though prr's state
// variables believe it has sent enough packets.
Expect(sender.TimeUntilSend(0)).To(BeZero())
})
It("slow start packet loss PRR", func() {
sender.SetNumEmulatedConnections(1)
// Test based on the first example in RFC6937.
// Ack 10 packets in 5 acks to raise the CWND to 20, as in the example.
const numberOfAcks = 5
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
LoseNPackets(1)
// We should now have fallen out of slow start with a reduced window.
sendWindowBeforeLoss := expectedSendWindow
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Testing TCP proportional rate reduction.
// We should send packets paced over the received acks for the remaining
// outstanding packets. The number of packets before we exit recovery is the
// original CWND minus the packet that has been lost and the one which
// triggered the loss.
remainingPacketsInRecovery := sendWindowBeforeLoss/protocol.DefaultTCPMSS - 2
for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ {
AckNPackets(1)
SendAvailableSendWindow()
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
// We need to ack another window before we increase CWND by 1.
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
for i := protocol.ByteCount(0); i < numberOfPacketsInWindow; i++ {
AckNPackets(1)
Expect(SendAvailableSendWindow()).To(Equal(1))
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
AckNPackets(1)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
})
It("slow start burst packet loss PRR", func() {
sender.SetNumEmulatedConnections(1)
// Test based on the second example in RFC6937, though we also implement
// forward acknowledgements, so the first two incoming acks will trigger
// PRR immediately.
// Ack 20 packets in 10 acks to raise the CWND to 30.
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Lose one more than the congestion window reduction, so that after loss,
// bytes_in_flight is lesser than the congestion window.
sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow))
numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/protocol.DefaultTCPMSS + 1
LoseNPackets(int(numPacketsToLose))
// Immediately after the loss, ensure at least one packet can be sent.
// Losses without subsequent acks can occur with timer based loss detection.
Expect(sender.TimeUntilSend(bytesInFlight)).To(BeZero())
AckNPackets(1)
// We should now have fallen out of slow start with a reduced window.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Only 2 packets should be allowed to be sent, per PRR-SSRB
Expect(SendAvailableSendWindow()).To(Equal(2))
// Ack the next packet, which triggers another loss.
LoseNPackets(1)
AckNPackets(1)
// Send 2 packets to simulate PRR-SSRB.
Expect(SendAvailableSendWindow()).To(Equal(2))
// Ack the next packet, which triggers another loss.
LoseNPackets(1)
AckNPackets(1)
// Send 2 packets to simulate PRR-SSRB.
Expect(SendAvailableSendWindow()).To(Equal(2))
// Exit recovery and return to sending at the new rate.
for i := 0; i < numberOfAcks; i++ {
AckNPackets(1)
Expect(SendAvailableSendWindow()).To(Equal(1))
}
})
It("RTO congestion window", func() {
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
// Expect the window to decrease to the minimum once the RTO fires
// and slow start threshold to be set to 1/2 of the CWND.
sender.OnRetransmissionTimeout(true)
Expect(sender.GetCongestionWindow()).To(Equal(2 * protocol.DefaultTCPMSS))
Expect(sender.SlowstartThreshold()).To(Equal(5 * protocol.DefaultTCPMSS))
})
It("RTO congestion window no retransmission", func() {
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
// Expect the window to remain unchanged if the RTO fires but no
// packets are retransmitted.
sender.OnRetransmissionTimeout(false)
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
})
It("tcp cubic reset epoch on quiescence", func() {
const maxCongestionWindow = 50
const maxCongestionWindowBytes = maxCongestionWindow * protocol.DefaultTCPMSS
sender = NewCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*protocol.DefaultTCPMSS, maxCongestionWindowBytes)
numSent := SendAvailableSendWindow()
// Make sure we fall out of slow start.
savedCwnd := sender.GetCongestionWindow()
LoseNPackets(1)
Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow()))
// Ack the rest of the outstanding packets to get out of recovery.
for i := 1; i < numSent; i++ {
AckNPackets(1)
}
Expect(bytesInFlight).To(BeZero())
// Send a new window of data and ack all; cubic growth should occur.
savedCwnd = sender.GetCongestionWindow()
numSent = SendAvailableSendWindow()
for i := 0; i < numSent; i++ {
AckNPackets(1)
}
Expect(savedCwnd).To(BeNumerically("<", sender.GetCongestionWindow()))
Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow()))
Expect(bytesInFlight).To(BeZero())
// Quiescent time of 100 seconds
clock.Advance(100 * time.Second)
// Send new window of data and ack one packet. Cubic epoch should have
// been reset; ensure cwnd increase is not dramatic.
savedCwnd = sender.GetCongestionWindow()
SendAvailableSendWindow()
AckNPackets(1)
Expect(savedCwnd).To(BeNumerically("~", sender.GetCongestionWindow(), protocol.DefaultTCPMSS))
Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow()))
})
It("multiple losses in one window", func() {
SendAvailableSendWindow()
initialWindow := sender.GetCongestionWindow()
LosePacket(ackedPacketNumber + 1)
postLossWindow := sender.GetCongestionWindow()
Expect(initialWindow).To(BeNumerically(">", postLossWindow))
LosePacket(ackedPacketNumber + 3)
Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow))
LosePacket(packetNumber - 1)
Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow))
// Lose a later packet and ensure the window decreases.
LosePacket(packetNumber)
Expect(postLossWindow).To(BeNumerically(">", sender.GetCongestionWindow()))
})
It("2 connection congestion avoidance at end of recovery", func() {
sender.SetNumEmulatedConnections(2)
// Ack 10 packets in 5 acks to raise the CWND to 20.
const numberOfAcks = 5
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
LoseNPackets(1)
// We should now have fallen out of slow start with a reduced window.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * sender.RenoBeta())
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// No congestion window growth should occur in recovery phase, i.e., until the
// currently outstanding 20 packets are acked.
for i := 0; i < 10; i++ {
// Send our full send window.
SendAvailableSendWindow()
Expect(sender.InRecovery()).To(BeTrue())
AckNPackets(2)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
Expect(sender.InRecovery()).To(BeFalse())
// Out of recovery now. Congestion window should not grow for half an RTT.
packetsInSendWindow := expectedSendWindow / protocol.DefaultTCPMSS
SendAvailableSendWindow()
AckNPackets(int(packetsInSendWindow/2 - 2))
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Next ack should increase congestion window by 1MSS.
SendAvailableSendWindow()
AckNPackets(2)
expectedSendWindow += protocol.DefaultTCPMSS
packetsInSendWindow++
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Congestion window should remain steady again for half an RTT.
SendAvailableSendWindow()
AckNPackets(int(packetsInSendWindow/2 - 1))
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Next ack should cause congestion window to grow by 1MSS.
SendAvailableSendWindow()
AckNPackets(2)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
})
It("1 connection congestion avoidance at end of recovery", func() {
sender.SetNumEmulatedConnections(1)
// Ack 10 packets in 5 acks to raise the CWND to 20.
const numberOfAcks = 5
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
LoseNPackets(1)
// We should now have fallen out of slow start with a reduced window.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// No congestion window growth should occur in recovery phase, i.e., until the
// currently outstanding 20 packets are acked.
for i := 0; i < 10; i++ {
// Send our full send window.
SendAvailableSendWindow()
Expect(sender.InRecovery()).To(BeTrue())
AckNPackets(2)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
Expect(sender.InRecovery()).To(BeFalse())
// Out of recovery now. Congestion window should not grow during RTT.
for i := protocol.ByteCount(0); i < expectedSendWindow/protocol.DefaultTCPMSS-2; i += 2 {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
}
// Next ack should cause congestion window to grow by 1MSS.
SendAvailableSendWindow()
AckNPackets(2)
expectedSendWindow += protocol.DefaultTCPMSS
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
})
It("reset after connection migration", func() {
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
// Starts with slow start.
sender.SetNumEmulatedConnections(1)
const numberOfAcks = 10
for i := 0; i < numberOfAcks; i++ {
// Send our full send window.
SendAvailableSendWindow()
AckNPackets(2)
}
SendAvailableSendWindow()
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
// Loses a packet to exit slow start.
LoseNPackets(1)
// We should now have fallen out of slow start with a reduced window. Slow
// start threshold is also updated.
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
Expect(sender.SlowstartThreshold()).To(Equal(expectedSendWindow))
// Resets cwnd and slow start threshold on connection migrations.
sender.OnConnectionMigration()
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
})
It("default max cwnd", func() {
sender = NewCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*protocol.DefaultTCPMSS, protocol.DefaultMaxCongestionWindow)
defaultMaxCongestionWindowPackets := protocol.DefaultMaxCongestionWindow / protocol.DefaultTCPMSS
for i := 1; i < int(defaultMaxCongestionWindowPackets); i++ {
sender.MaybeExitSlowStart()
sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now())
}
Expect(sender.GetCongestionWindow()).To(Equal(protocol.DefaultMaxCongestionWindow))
})
It("limit cwnd increase in congestion avoidance", func() {
// Enable Cubic.
sender = NewCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*protocol.DefaultTCPMSS, MaxCongestionWindow)
numSent := SendAvailableSendWindow()
// Make sure we fall out of slow start.
savedCwnd := sender.GetCongestionWindow()
LoseNPackets(1)
Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow()))
// Ack the rest of the outstanding packets to get out of recovery.
for i := 1; i < numSent; i++ {
AckNPackets(1)
}
Expect(bytesInFlight).To(BeZero())
savedCwnd = sender.GetCongestionWindow()
SendAvailableSendWindow()
// Ack packets until the CWND increases.
for sender.GetCongestionWindow() == savedCwnd {
AckNPackets(1)
SendAvailableSendWindow()
}
// Bytes in flight may be larger than the CWND if the CWND isn't an exact
// multiple of the packet sizes being sent.
Expect(bytesInFlight).To(BeNumerically(">=", sender.GetCongestionWindow()))
savedCwnd = sender.GetCongestionWindow()
// Advance time 2 seconds waiting for an ack.
clock.Advance(2 * time.Second)
// Ack two packets. The CWND should increase by only one packet.
AckNPackets(2)
Expect(sender.GetCongestionWindow()).To(Equal(savedCwnd + protocol.DefaultTCPMSS))
})
})

View File

@ -0,0 +1,236 @@
package congestion
import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const numConnections uint32 = 2
const nConnectionBeta float32 = (float32(numConnections) - 1 + beta) / float32(numConnections)
const nConnectionBetaLastMax float32 = (float32(numConnections) - 1 + betaLastMax) / float32(numConnections)
const nConnectionAlpha float32 = 3 * float32(numConnections) * float32(numConnections) * (1 - nConnectionBeta) / (1 + nConnectionBeta)
const maxCubicTimeInterval = 30 * time.Millisecond
var _ = Describe("Cubic", func() {
var (
clock mockClock
cubic *Cubic
)
BeforeEach(func() {
clock = mockClock{}
cubic = NewCubic(&clock)
})
renoCwnd := func(currentCwnd protocol.ByteCount) protocol.ByteCount {
return currentCwnd + protocol.ByteCount(float32(protocol.DefaultTCPMSS)*nConnectionAlpha*float32(protocol.DefaultTCPMSS)/float32(currentCwnd))
}
cubicConvexCwnd := func(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount {
offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000
deltaCongestionWindow := 410 * offset * offset * offset * protocol.DefaultTCPMSS >> 40
return initialCwnd + deltaCongestionWindow
}
It("works above origin (with tighter bounds)", func() {
// Convex growth.
const rttMin = 100 * time.Millisecond
const rttMinS = float32(rttMin/time.Millisecond) / 1000.0
currentCwnd := 10 * protocol.DefaultTCPMSS
initialCwnd := currentCwnd
clock.Advance(time.Millisecond)
initialTime := clock.Now()
expectedFirstCwnd := renoCwnd(currentCwnd)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, initialTime)
Expect(expectedFirstCwnd).To(Equal(currentCwnd))
// Normal TCP phase.
// The maximum number of expected reno RTTs can be calculated by
// finding the point where the cubic curve and the reno curve meet.
maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2)
for i := 0; i < maxRenoRtts; i++ {
// Alternatively, we expect it to increase by one, every time we
// receive current_cwnd/Alpha acks back. (This is another way of
// saying we expect cwnd to increase by approximately Alpha once
// we receive current_cwnd number ofacks back).
numAcksThisEpoch := int(float32(currentCwnd/protocol.DefaultTCPMSS) / nConnectionAlpha)
initialCwndThisEpoch := currentCwnd
for n := 0; n < numAcksThisEpoch; n++ {
// Call once per ACK.
expectedNextCwnd := renoCwnd(currentCwnd)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
Expect(currentCwnd).To(Equal(expectedNextCwnd))
}
// Our byte-wise Reno implementation is an estimate. We expect
// the cwnd to increase by approximately one MSS every
// cwnd/kDefaultTCPMSS/Alpha acks, but it may be off by as much as
// half a packet for smaller values of current_cwnd.
cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch
Expect(cwndChangeThisEpoch).To(BeNumerically("~", protocol.DefaultTCPMSS, protocol.DefaultTCPMSS/2))
clock.Advance(100 * time.Millisecond)
}
for i := 0; i < 54; i++ {
maxAcksThisEpoch := currentCwnd / protocol.DefaultTCPMSS
interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond
for n := 0; n < int(maxAcksThisEpoch); n++ {
clock.Advance(interval)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
// If we allow per-ack updates, every update is a small cubic update.
Expect(currentCwnd).To(Equal(expectedCwnd))
}
}
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
Expect(currentCwnd).To(Equal(expectedCwnd))
})
It("works above the origin with fine grained cubing", func() {
// Start the test with an artificially large cwnd to prevent Reno
// from over-taking cubic.
currentCwnd := 1000 * protocol.DefaultTCPMSS
initialCwnd := currentCwnd
rttMin := 100 * time.Millisecond
clock.Advance(time.Millisecond)
initialTime := clock.Now()
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
clock.Advance(600 * time.Millisecond)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
// We expect the algorithm to perform only non-zero, fine-grained cubic
// increases on every ack in this case.
for i := 0; i < 100; i++ {
clock.Advance(10 * time.Millisecond)
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
nextCwnd := cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
// Make sure we are performing cubic increases.
Expect(nextCwnd).To(Equal(expectedCwnd))
// Make sure that these are non-zero, less-than-packet sized increases.
Expect(nextCwnd).To(BeNumerically(">", currentCwnd))
cwndDelta := nextCwnd - currentCwnd
Expect(protocol.DefaultTCPMSS / 10).To(BeNumerically(">", cwndDelta))
currentCwnd = nextCwnd
}
})
It("handles per ack updates", func() {
// Start the test with a large cwnd and RTT, to force the first
// increase to be a cubic increase.
initialCwndPackets := 150
currentCwnd := protocol.ByteCount(initialCwndPackets) * protocol.DefaultTCPMSS
rttMin := 350 * time.Millisecond
// Initialize the epoch
clock.Advance(time.Millisecond)
// Keep track of the growth of the reno-equivalent cwnd.
rCwnd := renoCwnd(currentCwnd)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
initialCwnd := currentCwnd
// Simulate the return of cwnd packets in less than
// MaxCubicInterval() time.
maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha)
interval := maxCubicTimeInterval / time.Duration(maxAcks+1)
// In this scenario, the first increase is dictated by the cubic
// equation, but it is less than one byte, so the cwnd doesn't
// change. Normally, without per-ack increases, any cwnd plateau
// will cause the cwnd to be pinned for MaxCubicTimeInterval(). If
// we enable per-ack updates, the cwnd will continue to grow,
// regardless of the temporary plateau.
clock.Advance(interval)
rCwnd = renoCwnd(rCwnd)
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(currentCwnd))
for i := 1; i < maxAcks; i++ {
clock.Advance(interval)
nextCwnd := cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
rCwnd = renoCwnd(rCwnd)
// The window shoud increase on every ack.
Expect(nextCwnd).To(BeNumerically(">", currentCwnd))
Expect(nextCwnd).To(Equal(rCwnd))
currentCwnd = nextCwnd
}
// After all the acks are returned from the epoch, we expect the
// cwnd to have increased by nearly one packet. (Not exactly one
// packet, because our byte-wise Reno algorithm is always a slight
// under-estimation). Without per-ack updates, the current_cwnd
// would otherwise be unchanged.
minimumExpectedIncrease := protocol.DefaultTCPMSS * 9 / 10
Expect(currentCwnd).To(BeNumerically(">", initialCwnd+minimumExpectedIncrease))
})
It("handles loss events", func() {
rttMin := 100 * time.Millisecond
currentCwnd := 422 * protocol.DefaultTCPMSS
expectedCwnd := renoCwnd(currentCwnd)
// Initialize the state.
clock.Advance(time.Millisecond)
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd))
// On the first loss, the last max congestion window is set to the
// congestion window before the loss.
preLossCwnd := currentCwnd
Expect(cubic.lastMaxCongestionWindow).To(BeZero())
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
Expect(cubic.lastMaxCongestionWindow).To(Equal(preLossCwnd))
currentCwnd = expectedCwnd
// On the second loss, the current congestion window has not yet
// reached the last max congestion window. The last max congestion
// window will be reduced by an additional backoff factor to allow
// for competition.
preLossCwnd = currentCwnd
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
currentCwnd = expectedCwnd
Expect(preLossCwnd).To(BeNumerically(">", cubic.lastMaxCongestionWindow))
expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax)
Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax))
Expect(expectedCwnd).To(BeNumerically("<", cubic.lastMaxCongestionWindow))
// Simulate an increase, and check that we are below the origin.
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
Expect(cubic.lastMaxCongestionWindow).To(BeNumerically(">", currentCwnd))
// On the final loss, simulate the condition where the congestion
// window had a chance to grow nearly to the last congestion window.
currentCwnd = cubic.lastMaxCongestionWindow - 1
preLossCwnd = currentCwnd
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
expectedLastMax = preLossCwnd
Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax))
})
It("works below origin", func() {
// Concave growth.
rttMin := 100 * time.Millisecond
currentCwnd := 422 * protocol.DefaultTCPMSS
expectedCwnd := renoCwnd(currentCwnd)
// Initialize the state.
clock.Advance(time.Millisecond)
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd))
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
currentCwnd = expectedCwnd
// First update after loss to initialize the epoch.
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
// Cubic phase.
for i := 0; i < 40; i++ {
clock.Advance(100 * time.Millisecond)
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
}
expectedCwnd = 553632
Expect(currentCwnd).To(Equal(expectedCwnd))
})
})

View File

@ -0,0 +1,111 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Note(pwestin): the magic clamping numbers come from the original code in
// tcp_cubic.c.
const hybridStartLowWindow = protocol.ByteCount(16)
// Number of delay samples for detecting the increase of delay.
const hybridStartMinSamples = uint32(8)
// Exit slow start if the min rtt has increased by more than 1/8th.
const hybridStartDelayFactorExp = 3 // 2^3 = 8
// The original paper specifies 2 and 8ms, but those have changed over time.
const hybridStartDelayMinThresholdUs = int64(4000)
const hybridStartDelayMaxThresholdUs = int64(16000)
// HybridSlowStart implements the TCP hybrid slow start algorithm
type HybridSlowStart struct {
endPacketNumber protocol.PacketNumber
lastSentPacketNumber protocol.PacketNumber
started bool
currentMinRTT time.Duration
rttSampleCount uint32
hystartFound bool
}
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
s.endPacketNumber = lastSent
s.currentMinRTT = 0
s.rttSampleCount = 0
s.started = true
}
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
return s.endPacketNumber < ack
}
// ShouldExitSlowStart should be called on every new ack frame, since a new
// RTT measurement can be made then.
// rtt: the RTT for this ack packet.
// minRTT: is the lowest delay (RTT) we have seen during the session.
// congestionWindow: the congestion window in packets.
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
if !s.started {
// Time to start the hybrid slow start.
s.StartReceiveRound(s.lastSentPacketNumber)
}
if s.hystartFound {
return true
}
// Second detection parameter - delay increase detection.
// Compare the minimum delay (s.currentMinRTT) of the current
// burst of packets relative to the minimum delay during the session.
// Note: we only look at the first few(8) packets in each burst, since we
// only want to compare the lowest RTT of the burst relative to previous
// bursts.
s.rttSampleCount++
if s.rttSampleCount <= hybridStartMinSamples {
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
s.currentMinRTT = latestRTT
}
}
// We only need to check this once per round.
if s.rttSampleCount == hybridStartMinSamples {
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
minRTTincreaseThresholdUs = utils.MinInt64(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
minRTTincreaseThreshold := time.Duration(utils.MaxInt64(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
s.hystartFound = true
}
}
// Exit from slow start if the cwnd is greater than 16 and
// increasing delay is found.
return congestionWindow >= hybridStartLowWindow && s.hystartFound
}
// OnPacketSent is called when a packet was sent
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
s.lastSentPacketNumber = packetNumber
}
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
// the round when the final packet of the burst is received and start it on
// the next incoming ack.
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
if s.IsEndOfRound(ackedPacketNumber) {
s.started = false
}
}
// Started returns true if started
func (s *HybridSlowStart) Started() bool {
return s.started
}
// Restart the slow start phase
func (s *HybridSlowStart) Restart() {
s.started = false
s.hystartFound = false
}

View File

@ -0,0 +1,75 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Hybrid slow start", func() {
var (
slowStart HybridSlowStart
)
BeforeEach(func() {
slowStart = HybridSlowStart{}
})
It("works in a simple case", func() {
packetNumber := protocol.PacketNumber(1)
endPacketNumber := protocol.PacketNumber(3)
slowStart.StartReceiveRound(endPacketNumber)
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
// Test duplicates.
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
// Test without a new registered end_packet_number;
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
endPacketNumber = 20
slowStart.StartReceiveRound(endPacketNumber)
for packetNumber < endPacketNumber {
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
}
packetNumber++
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
})
It("works with delay", func() {
rtt := 60 * time.Millisecond
// We expect to detect the increase at +1/8 of the RTT; hence at a typical
// RTT of 60ms the detection will happen at 67.5 ms.
const hybridStartMinSamples = 8 // Number of acks required to trigger.
endPacketNumber := protocol.PacketNumber(1)
endPacketNumber++
slowStart.StartReceiveRound(endPacketNumber)
// Will not trigger since our lowest RTT in our burst is the same as the long
// term RTT provided.
for n := 0; n < hybridStartMinSamples; n++ {
Expect(slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)).To(BeFalse())
}
endPacketNumber++
slowStart.StartReceiveRound(endPacketNumber)
for n := 1; n < hybridStartMinSamples; n++ {
Expect(slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)).To(BeFalse())
}
// Expect to trigger since all packets in this burst was above the long term
// RTT provided.
Expect(slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)).To(BeTrue())
})
})

View File

@ -0,0 +1,36 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A SendAlgorithm performs congestion control and calculates the congestion window
type SendAlgorithm interface {
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
GetCongestionWindow() protocol.ByteCount
MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration()
// Experiments
SetSlowStartLargeReduction(enabled bool)
}
// SendAlgorithmWithDebugInfo adds some debug functions to SendAlgorithm
type SendAlgorithmWithDebugInfo interface {
SendAlgorithm
BandwidthEstimate() Bandwidth
// Stuff only used in testing
HybridSlowStart() *HybridSlowStart
SlowstartThreshold() protocol.ByteCount
RenoBeta() float32
InRecovery() bool
}

View File

@ -0,0 +1,54 @@
package congestion
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
type PrrSender struct {
bytesSentSinceLoss protocol.ByteCount
bytesDeliveredSinceLoss protocol.ByteCount
ackCountSinceLoss protocol.ByteCount
bytesInFlightBeforeLoss protocol.ByteCount
}
// OnPacketSent should be called after a packet was sent
func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
p.bytesSentSinceLoss += sentBytes
}
// OnPacketLost should be called on the first loss that triggers a recovery
// period and all other methods in this class should only be called when in
// recovery.
func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
p.bytesSentSinceLoss = 0
p.bytesInFlightBeforeLoss = priorInFlight
p.bytesDeliveredSinceLoss = 0
p.ackCountSinceLoss = 0
}
// OnPacketAcked should be called after a packet was acked
func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
p.bytesDeliveredSinceLoss += ackedBytes
p.ackCountSinceLoss++
}
// CanSend returns if packets can be sent
func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
// Return QuicTime::Zero In order to ensure limited transmit always works.
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
return true
}
if congestionWindow > bytesInFlight {
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
// of sending the entire available window. This prevents burst retransmits
// when more packets are lost than the CWND reduction.
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
}
// Implement Proportional Rate Reduction (RFC6937).
// Checks a simplified version of the PRR formula that doesn't use division:
// AvailableSendWindow =
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
}

Some files were not shown because too many files have changed in this diff Show More