mirror of https://github.com/v2ray/v2ray-core
vendor quic 44
parent
bb8cab9cc7
commit
84f8bca01c
|
@ -0,0 +1,44 @@
|
|||
# Changelog
|
||||
|
||||
## v0.10.0 (2018-08-28)
|
||||
|
||||
- Add support for QUIC 44, drop support for QUIC 42.
|
||||
|
||||
## v0.9.0 (2018-08-15)
|
||||
|
||||
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
|
||||
- Split Session.Close into one method for regular closing and one for closing with an error.
|
||||
|
||||
## v0.8.0 (2018-06-26)
|
||||
|
||||
- Add support for unidirectional streams (for IETF QUIC).
|
||||
- Add a `quic.Config` option for the maximum number of incoming streams.
|
||||
- Add support for QUIC 42 and 43.
|
||||
- Add dial functions that use a context.
|
||||
- Multiplex clients on a net.PacketConn, when using Dial(conn).
|
||||
|
||||
## v0.7.0 (2018-02-03)
|
||||
|
||||
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
|
||||
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
|
||||
- Expose the `ConnectionState` in the `Session` (experimental API).
|
||||
- Implement packet pacing.
|
||||
|
||||
## v0.6.0 (2017-12-12)
|
||||
|
||||
- Add support for QUIC 39, drop support for QUIC 35 - 37
|
||||
- Added `quic.Config` options for maximal flow control windows
|
||||
- Add a `quic.Config` option for QUIC versions
|
||||
- Add a `quic.Config` option to request omission of the connection ID from a server
|
||||
- Add a `quic.Config` option to configure the source address validation
|
||||
- Add a `quic.Config` option to configure the handshake timeout
|
||||
- Add a `quic.Config` option to configure the idle timeout
|
||||
- Add a `quic.Config` option to configure keep-alive
|
||||
- Rename the STK to Cookie
|
||||
- Implement `net.Conn`-style deadlines for streams
|
||||
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details.
|
||||
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details.
|
||||
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
|
||||
- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn`
|
||||
- Drop support for Go 1.7 and 1.8.
|
||||
- Various bugfixes
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2016 the quic-go authors & Google, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,73 @@
|
|||
# A QUIC implementation in pure Go
|
||||
|
||||
<img src="docs/quic.png" width=303 height=124>
|
||||
|
||||
[](https://godoc.org/github.com/lucas-clemente/quic-go)
|
||||
[](https://travis-ci.org/lucas-clemente/quic-go)
|
||||
[](https://circleci.com/gh/lucas-clemente/quic-go)
|
||||
[](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
|
||||
[](https://codecov.io/gh/lucas-clemente/quic-go/)
|
||||
|
||||
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
|
||||
|
||||
## Roadmap
|
||||
|
||||
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
|
||||
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
|
||||
|
||||
## Guides
|
||||
|
||||
We currently support Go 1.9+.
|
||||
|
||||
Installing and updating dependencies:
|
||||
|
||||
go get -t -u ./...
|
||||
|
||||
Running tests:
|
||||
|
||||
go test ./...
|
||||
|
||||
### Running the example server
|
||||
|
||||
go run example/main.go -www /var/www/
|
||||
|
||||
Using the `quic_client` from chromium:
|
||||
|
||||
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
|
||||
|
||||
Using Chrome:
|
||||
|
||||
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
|
||||
|
||||
### QUIC without HTTP/2
|
||||
|
||||
Take a look at [this echo example](example/echo/echo.go).
|
||||
|
||||
### Using the example client
|
||||
|
||||
go run example/client/main.go https://clemente.io
|
||||
|
||||
## Usage
|
||||
|
||||
### As a server
|
||||
|
||||
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
|
||||
|
||||
```go
|
||||
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
||||
h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
|
||||
```
|
||||
|
||||
### As a client
|
||||
|
||||
See the [example client](example/client/main.go). Use a `h2quic.RoundTripper` as a `Transport` in a `http.Client`.
|
||||
|
||||
```go
|
||||
http.Client{
|
||||
Transport: &h2quic.RoundTripper{},
|
||||
}
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/lucas-clemente/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.
|
|
@ -0,0 +1,35 @@
|
|||
version: "{build}"
|
||||
|
||||
os: Windows Server 2012 R2
|
||||
|
||||
environment:
|
||||
GOPATH: c:\gopath
|
||||
CGO_ENABLED: 0
|
||||
TIMESCALE_FACTOR: 20
|
||||
matrix:
|
||||
- GOARCH: 386
|
||||
- GOARCH: amd64
|
||||
|
||||
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
|
||||
|
||||
install:
|
||||
- rmdir c:\go /s /q
|
||||
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.11.windows-amd64.zip
|
||||
- 7z x go1.11.windows-amd64.zip -y -oC:\ > NUL
|
||||
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
|
||||
- echo %PATH%
|
||||
- echo %GOPATH%
|
||||
- git submodule update --init --recursive
|
||||
- go get github.com/onsi/ginkgo/ginkgo
|
||||
- go get github.com/onsi/gomega
|
||||
- go version
|
||||
- go env
|
||||
- go get -v -t ./...
|
||||
|
||||
build_script:
|
||||
- ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage benchmark,integrationtests
|
||||
- ginkgo -randomizeAllSpecs -randomizeSuites -trace benchmark -- -samples=1
|
||||
|
||||
test: off
|
||||
|
||||
deploy: off
|
|
@ -0,0 +1,26 @@
|
|||
package benchmark
|
||||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBenchmark(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Benchmark Suite")
|
||||
}
|
||||
|
||||
var (
|
||||
size int // file size in MB, will be read from flags
|
||||
samples int // number of samples for Measure, will be read from flags
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.IntVar(&size, "size", 50, "data length (in MB)")
|
||||
flag.IntVar(&samples, "samples", 6, "number of samples")
|
||||
flag.Parse()
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
package benchmark
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func init() {
|
||||
var _ = Describe("Benchmarks", func() {
|
||||
dataLen := size * /* MB */ 1e6
|
||||
data := make([]byte, dataLen)
|
||||
rand.Seed(GinkgoRandomSeed())
|
||||
rand.Read(data) // no need to check for an error. math.Rand.Read never errors
|
||||
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
Context(fmt.Sprintf("with version %s", version), func() {
|
||||
Measure(fmt.Sprintf("transferring a %d MB file", size), func(b Benchmarker) {
|
||||
var ln quic.Listener
|
||||
serverAddr := make(chan net.Addr)
|
||||
handshakeChan := make(chan struct{})
|
||||
// start the server
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
ln, err = quic.ListenAddr(
|
||||
"localhost:0",
|
||||
testdata.GetTLSConfig(),
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr <- ln.Addr()
|
||||
sess, err := ln.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// wait for the client to complete the handshake before sending the data
|
||||
// this should not be necessary, but due to timing issues on the CIs, this is necessary to avoid sending too many undecryptable packets
|
||||
<-handshakeChan
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = str.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
|
||||
// start the client
|
||||
addr := <-serverAddr
|
||||
sess, err := quic.DialAddr(
|
||||
addr.String(),
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(handshakeChan)
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
// measure the time it takes to download the dataLen bytes
|
||||
// note we're measuring the time for the transfer, i.e. excluding the handshake
|
||||
runtime := b.Time("transfer time", func() {
|
||||
_, err := io.Copy(buf, str)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
Expect(buf.Bytes()).To(Equal(data))
|
||||
|
||||
b.RecordValue("transfer rate [MB/s]", float64(dataLen)/1e6/runtime.Seconds())
|
||||
|
||||
ln.Close()
|
||||
sess.Close()
|
||||
}, samples)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var bufferPool sync.Pool
|
||||
|
||||
func getPacketBuffer() *[]byte {
|
||||
return bufferPool.Get().(*[]byte)
|
||||
}
|
||||
|
||||
func putPacketBuffer(buf *[]byte) {
|
||||
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
bufferPool.Put(buf)
|
||||
}
|
||||
|
||||
func init() {
|
||||
bufferPool.New = func() interface{} {
|
||||
b := make([]byte, 0, protocol.MaxReceivePacketSize)
|
||||
return &b
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Buffer Pool", func() {
|
||||
It("returns buffers of cap", func() {
|
||||
buf := *getPacketBuffer()
|
||||
Expect(buf).To(HaveCap(int(protocol.MaxReceivePacketSize)))
|
||||
})
|
||||
|
||||
It("panics if wrong-sized buffers are passed", func() {
|
||||
Expect(func() {
|
||||
putPacketBuffer(&[]byte{0})
|
||||
}).To(Panic())
|
||||
})
|
||||
})
|
|
@ -0,0 +1,595 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conn connection
|
||||
// If the client is created with DialAddr, we create a packet conn.
|
||||
// If it is started with Dial, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
|
||||
hostname string
|
||||
|
||||
packetHandlers packetHandlerManager
|
||||
|
||||
token []byte
|
||||
numRetries int
|
||||
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
tlsConf *tls.Config
|
||||
mintConf *mint.Config
|
||||
config *Config
|
||||
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
||||
handshakeChan chan struct{}
|
||||
closeCallback func(protocol.ConnectionID)
|
||||
|
||||
session quicSession
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ packetHandler = &client{}
|
||||
|
||||
var (
|
||||
// make it possible to mock connection ID generation in the tests
|
||||
generateConnectionID = protocol.GenerateConnectionID
|
||||
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
|
||||
)
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddr(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
return DialAddrContext(context.Background(), addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
func DialAddrContext(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// The host parameter is used for SNI.
|
||||
func Dial(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
|
||||
// The host parameter is used for SNI.
|
||||
func DialContext(
|
||||
ctx context.Context,
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
|
||||
}
|
||||
|
||||
func dialContext(
|
||||
ctx context.Context,
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
createdPacketConn bool,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
if !createdPacketConn {
|
||||
for _, v := range config.Versions {
|
||||
if v == protocol.Version44 {
|
||||
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
|
||||
}
|
||||
}
|
||||
}
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.packetHandlers = packetHandlers
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
func newClient(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
host string,
|
||||
closeCallback func(protocol.ConnectionID),
|
||||
createdPacketConn bool,
|
||||
) (*client, error) {
|
||||
var hostname string
|
||||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
if hostname == "" {
|
||||
var err error
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// check that all versions are actually supported
|
||||
if config != nil {
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
onClose := func(protocol.ConnectionID) {}
|
||||
if closeCallback != nil {
|
||||
onClose = closeCallback
|
||||
}
|
||||
c := &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
createdPacketConn: createdPacketConn,
|
||||
hostname: hostname,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, c.generateConnectionIDs()
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
versions := config.Versions
|
||||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
}
|
||||
|
||||
handshakeTimeout := protocol.DefaultHandshakeTimeout
|
||||
if config.HandshakeTimeout != 0 {
|
||||
handshakeTimeout = config.HandshakeTimeout
|
||||
}
|
||||
idleTimeout := protocol.DefaultIdleTimeout
|
||||
if config.IdleTimeout != 0 {
|
||||
idleTimeout = config.IdleTimeout
|
||||
}
|
||||
|
||||
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
|
||||
if maxReceiveStreamFlowControlWindow == 0 {
|
||||
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
|
||||
}
|
||||
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
|
||||
if maxReceiveConnectionFlowControlWindow == 0 {
|
||||
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
|
||||
}
|
||||
maxIncomingStreams := config.MaxIncomingStreams
|
||||
if maxIncomingStreams == 0 {
|
||||
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
|
||||
} else if maxIncomingStreams < 0 {
|
||||
maxIncomingStreams = 0
|
||||
}
|
||||
maxIncomingUniStreams := config.MaxIncomingUniStreams
|
||||
if maxIncomingUniStreams == 0 {
|
||||
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
|
||||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDLen := config.ConnectionIDLength
|
||||
if connIDLen == 0 && !createdPacketConn {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
for _, v := range versions {
|
||||
if v == protocol.Version44 {
|
||||
connIDLen = 0
|
||||
}
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
IdleTimeout: idleTimeout,
|
||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
ConnectionIDLength: connIDLen,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
KeepAlive: config.KeepAlive,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) generateConnectionIDs() error {
|
||||
connIDLen := protocol.ConnectionIDLenGQUIC
|
||||
if c.version.UsesTLS() {
|
||||
connIDLen = c.config.ConnectionIDLength
|
||||
}
|
||||
srcConnID, err := generateConnectionID(connIDLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destConnID := srcConnID
|
||||
if c.version.UsesTLS() {
|
||||
destConnID, err = generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.srcConnID = srcConnID
|
||||
c.destConnID = destConnID
|
||||
if c.version == protocol.Version44 {
|
||||
c.srcConnID = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
var err error
|
||||
if c.version.UsesTLS() {
|
||||
err = c.dialTLS(ctx)
|
||||
} else {
|
||||
err = c.dialGQUIC(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialGQUIC(ctx context.Context) error {
|
||||
if err := c.createNewGQUICSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
err := c.establishSecureConnection(ctx)
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialTLS(ctx context.Context) error {
|
||||
params := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
|
||||
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
|
||||
DisableMigration: true,
|
||||
}
|
||||
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
|
||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mintConf.ExtensionHandler = extHandler
|
||||
mintConf.ServerName = c.hostname
|
||||
c.mintConf = mintConf
|
||||
|
||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.establishSecureConnection(ctx)
|
||||
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
// It returns:
|
||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := c.session.run() // returns as soon as the session is closed
|
||||
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
|
||||
c.conn.Close()
|
||||
}
|
||||
errorChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// The session will send a PeerGoingAway error to the server.
|
||||
c.session.Close()
|
||||
return ctx.Err()
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
case <-c.handshakeChan:
|
||||
// handshake successfully completed
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handlePacket(p *receivedPacket) {
|
||||
if err := c.handlePacketImpl(p); err != nil {
|
||||
c.logger.Errorf("error handling packet: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handlePacketImpl(p *receivedPacket) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// handle Version Negotiation Packets
|
||||
if p.header.IsVersionNegotiation {
|
||||
err := c.handleVersionNegotiationPacket(p.header)
|
||||
if err != nil {
|
||||
c.session.destroy(err)
|
||||
}
|
||||
// version negotiation packets have no payload
|
||||
return err
|
||||
}
|
||||
|
||||
if !c.version.UsesIETFHeaderFormat() {
|
||||
connID := p.header.DestConnectionID
|
||||
// reject packets with truncated connection id if we didn't request truncation
|
||||
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
|
||||
return errors.New("received packet with truncated connection ID, but didn't request truncation")
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
|
||||
}
|
||||
if p.header.ResetFlag {
|
||||
return c.handlePublicReset(p)
|
||||
}
|
||||
} else {
|
||||
// reject packets with the wrong connection ID
|
||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
}
|
||||
|
||||
if p.header.IsLongHeader {
|
||||
switch p.header.Type {
|
||||
case protocol.PacketTypeRetry:
|
||||
c.handleRetryPacket(p.header)
|
||||
return nil
|
||||
case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
|
||||
default:
|
||||
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// this is the first packet we are receiving
|
||||
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
|
||||
if !c.versionNegotiated {
|
||||
c.versionNegotiated = true
|
||||
}
|
||||
|
||||
c.session.handlePacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handlePublicReset(p *receivedPacket) error {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return errors.New("Received a spoofed Public Reset")
|
||||
}
|
||||
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
}
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, v := range hdr.SupportedVersions {
|
||||
if v == c.version {
|
||||
// the version negotiation packet contains the version that we offered
|
||||
// this might be a packet sent by an attacker (or by a terribly broken server implementation)
|
||||
// ignore it
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
return qerr.InvalidVersion
|
||||
}
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
c.negotiatedVersions = hdr.SupportedVersions
|
||||
|
||||
// switch to negotiated version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
if err := c.generateConnectionIDs(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
|
||||
c.session.destroy(errCloseSessionForNewVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleRetryPacket(hdr *wire.Header) {
|
||||
c.logger.Debugf("<- Received Retry")
|
||||
hdr.Log(c.logger)
|
||||
// A server that performs multiple retries must use a source connection ID of at least 8 bytes.
|
||||
// Only a server that won't send additional Retries can use shorter connection IDs.
|
||||
if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
|
||||
c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial)
|
||||
return
|
||||
}
|
||||
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
|
||||
c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
|
||||
return
|
||||
}
|
||||
c.numRetries++
|
||||
if c.numRetries > protocol.MaxRetries {
|
||||
c.session.destroy(qerr.CryptoTooManyRejects)
|
||||
return
|
||||
}
|
||||
c.destConnID = hdr.SrcConnectionID
|
||||
c.token = hdr.Token
|
||||
c.session.destroy(errCloseSessionForRetry)
|
||||
}
|
||||
|
||||
func (c *client) createNewGQUICSession() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
sess, err := newClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
c.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.session = sess
|
||||
c.packetHandlers.Add(c.srcConnID, c)
|
||||
if c.config.RequestConnectionIDOmission {
|
||||
c.packetHandlers.Add(protocol.ConnectionID{}, c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
version protocol.VersionNumber,
|
||||
) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
sess, err := newTLSClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.token,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.config,
|
||||
c.mintConf,
|
||||
paramsChan,
|
||||
1,
|
||||
c.logger,
|
||||
c.version,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.session = sess
|
||||
c.packetHandlers.Add(c.srcConnID, c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.Close()
|
||||
}
|
||||
|
||||
func (c *client) destroy(e error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.session == nil {
|
||||
return
|
||||
}
|
||||
c.session.destroy(e)
|
||||
}
|
||||
|
||||
func (c *client) GetVersion() protocol.VersionNumber {
|
||||
c.mutex.Lock()
|
||||
v := c.version
|
||||
c.mutex.Unlock()
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *client) GetPerspective() protocol.Perspective {
|
||||
return protocol.PerspectiveClient
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,18 @@
|
|||
coverage:
|
||||
round: nearest
|
||||
ignore:
|
||||
- streams_map_incoming_bidi.go
|
||||
- streams_map_incoming_uni.go
|
||||
- streams_map_outgoing_bidi.go
|
||||
- streams_map_outgoing_uni.go
|
||||
- h2quic/gzipreader.go
|
||||
- h2quic/response.go
|
||||
- internal/ackhandler/packet_linkedlist.go
|
||||
- internal/utils/byteinterval_linkedlist.go
|
||||
- internal/utils/packetinterval_linkedlist.go
|
||||
- internal/utils/linkedlist/linkedlist.go
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
threshold: 0.5
|
||||
patch: false
|
|
@ -0,0 +1,54 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type connection interface {
|
||||
Write([]byte) error
|
||||
Read([]byte) (int, net.Addr, error)
|
||||
Close() error
|
||||
LocalAddr() net.Addr
|
||||
RemoteAddr() net.Addr
|
||||
SetCurrentRemoteAddr(net.Addr)
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
pconn net.PacketConn
|
||||
currentAddr net.Addr
|
||||
}
|
||||
|
||||
var _ connection = &conn{}
|
||||
|
||||
func (c *conn) Write(p []byte) error {
|
||||
_, err := c.pconn.WriteTo(p, c.currentAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) Read(p []byte) (int, net.Addr, error) {
|
||||
return c.pconn.ReadFrom(p)
|
||||
}
|
||||
|
||||
func (c *conn) SetCurrentRemoteAddr(addr net.Addr) {
|
||||
c.mutex.Lock()
|
||||
c.currentAddr = addr
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.pconn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
c.mutex.RLock()
|
||||
addr := c.currentAddr
|
||||
c.mutex.RUnlock()
|
||||
return addr
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
return c.pconn.Close()
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockPacketConn struct {
|
||||
addr net.Addr
|
||||
dataToRead chan []byte
|
||||
dataReadFrom net.Addr
|
||||
readErr error
|
||||
dataWritten bytes.Buffer
|
||||
dataWrittenTo net.Addr
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newMockPacketConn() *mockPacketConn {
|
||||
return &mockPacketConn{
|
||||
dataToRead: make(chan []byte, 1000),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
if c.readErr != nil {
|
||||
return 0, nil, c.readErr
|
||||
}
|
||||
data, ok := <-c.dataToRead
|
||||
if !ok {
|
||||
return 0, nil, errors.New("connection closed")
|
||||
}
|
||||
n := copy(b, data)
|
||||
return n, c.dataReadFrom, nil
|
||||
}
|
||||
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
c.dataWrittenTo = addr
|
||||
return c.dataWritten.Write(b)
|
||||
}
|
||||
func (c *mockPacketConn) Close() error {
|
||||
if !c.closed {
|
||||
close(c.dataToRead)
|
||||
}
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
|
||||
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
|
||||
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
|
||||
func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") }
|
||||
|
||||
var _ net.PacketConn = &mockPacketConn{}
|
||||
|
||||
var _ = Describe("Connection", func() {
|
||||
var c *conn
|
||||
var packetConn *mockPacketConn
|
||||
|
||||
BeforeEach(func() {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.IPv4(192, 168, 100, 200),
|
||||
Port: 1337,
|
||||
}
|
||||
packetConn = newMockPacketConn()
|
||||
c = &conn{
|
||||
currentAddr: addr,
|
||||
pconn: packetConn,
|
||||
}
|
||||
})
|
||||
|
||||
It("writes", func() {
|
||||
err := c.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packetConn.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
||||
Expect(packetConn.dataWrittenTo.String()).To(Equal("192.168.100.200:1337"))
|
||||
})
|
||||
|
||||
It("reads", func() {
|
||||
packetConn.dataToRead <- []byte("foo")
|
||||
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
|
||||
p := make([]byte, 10)
|
||||
n, raddr, err := c.Read(p)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(raddr.String()).To(Equal("127.0.0.1:1336"))
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(p[0:3]).To(Equal([]byte("foo")))
|
||||
})
|
||||
|
||||
It("gets the remote address", func() {
|
||||
Expect(c.RemoteAddr().String()).To(Equal("192.168.100.200:1337"))
|
||||
})
|
||||
|
||||
It("gets the local address", func() {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.IPv4(192, 168, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
packetConn.addr = addr
|
||||
Expect(c.LocalAddr()).To(Equal(addr))
|
||||
})
|
||||
|
||||
It("changes the remote address", func() {
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 7331,
|
||||
}
|
||||
c.SetCurrentRemoteAddr(addr)
|
||||
Expect(c.RemoteAddr().String()).To(Equal(addr.String()))
|
||||
})
|
||||
|
||||
It("closes", func() {
|
||||
err := c.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packetConn.closed).To(BeTrue())
|
||||
})
|
||||
})
|
|
@ -0,0 +1,41 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStream interface {
|
||||
StreamID() protocol.StreamID
|
||||
io.Reader
|
||||
io.Writer
|
||||
handleStreamFrame(*wire.StreamFrame) error
|
||||
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
|
||||
closeForShutdown(error)
|
||||
setReadOffset(protocol.ByteCount)
|
||||
// methods needed for flow control
|
||||
getWindowUpdate() protocol.ByteCount
|
||||
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
|
||||
}
|
||||
|
||||
type cryptoStreamImpl struct {
|
||||
*stream
|
||||
}
|
||||
|
||||
var _ cryptoStream = &cryptoStreamImpl{}
|
||||
|
||||
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
|
||||
str := newStream(version.CryptoStreamID(), sender, flowController, version)
|
||||
return &cryptoStreamImpl{str}
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
|
||||
s.receiveStream.readOffset = offset
|
||||
s.receiveStream.frameQueue.readPos = offset
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Crypto Stream", func() {
|
||||
var (
|
||||
str *cryptoStreamImpl
|
||||
mockSender *MockStreamSender
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStreamImpl)
|
||||
})
|
||||
|
||||
It("sets the read offset", func() {
|
||||
str.setReadOffset(0x42)
|
||||
Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42)))
|
||||
Expect(str.receiveStream.frameQueue.readPos).To(Equal(protocol.ByteCount(0x42)))
|
||||
})
|
||||
})
|
Binary file not shown.
After Width: | Height: | Size: 17 KiB |
Binary file not shown.
|
@ -0,0 +1,65 @@
|
|||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<svg width="601px" height="242px" viewBox="0 0 601 242" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<!-- Generator: Sketch 3.7.2 (28276) - http://www.bohemiancoding.com/sketch -->
|
||||
<title>quic</title>
|
||||
<desc>Created with Sketch.</desc>
|
||||
<defs></defs>
|
||||
<g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="Group-4-Copy" transform="translate(586.087964, 15.755663) rotate(-26.000000) translate(-586.087964, -15.755663) translate(573.036290, 1.253803)">
|
||||
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
|
||||
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
|
||||
</g>
|
||||
<g id="Group-3-Copy" transform="translate(499.346306, 15.541651) rotate(25.000000) translate(-499.346306, -15.541651) translate(486.346306, 0.541651)">
|
||||
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
|
||||
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
|
||||
</g>
|
||||
<g id="Group-3" transform="translate(1.896652, 21.000000)">
|
||||
<path d="M0.262217125,3.72242032 C20.0373679,-10.1674613 37.9073886,21.170644 14.9248407,29.0555125 L0.262217125,3.72242032 L0.262217125,3.72242032 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.997366, 14.636187) scale(-1, 1) translate(-12.997366, -14.636187) "></path>
|
||||
<path d="M16.7752477,19.7502904 C20.0713133,18.0067525 22.4893949,15.6388059 20.473222,11.8264741 C18.6066255,8.29713567 15.1390235,8.68039911 11.8425436,10.423937 L16.7752477,19.7502904 L16.7752477,19.7502904 Z" id="Shape" fill="#000000" transform="translate(16.526068, 14.439699) scale(-1, 1) translate(-16.526068, -14.439699) "></path>
|
||||
</g>
|
||||
<g id="Group-4" transform="translate(117.000000, 22.000000)">
|
||||
<path d="M12.0494985,28.9428124 C-12.9338918,21.9172827 5.64340536,-9.86582259 25.705693,3.20698278 L12.0494985,28.9428124 L12.0494985,28.9428124 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#6AD7E5" transform="translate(12.861962, 14.514517) scale(-1, 1) translate(-12.861962, -14.514517) "></path>
|
||||
<path d="M12.1335889,20.1646293 C8.83752326,18.4210914 6.41944167,16.0531448 8.43561456,12.2408129 C10.3022111,8.71147453 13.769813,9.09473797 17.066293,10.8382759 L12.1335889,20.1646293 L12.1335889,20.1646293 Z" id="Shape" fill="#000000" transform="translate(12.382769, 14.854037) scale(-1, 1) translate(-12.382769, -14.854037) "></path>
|
||||
</g>
|
||||
<g id="Group-2" transform="translate(26.000000, 40.000000)">
|
||||
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
|
||||
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
|
||||
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
|
||||
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
|
||||
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
|
||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
|
||||
</g>
|
||||
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
|
||||
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
|
||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
|
||||
</g>
|
||||
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
|
||||
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
|
||||
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
|
||||
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
|
||||
</g>
|
||||
</g>
|
||||
<g id="Group-2-Copy" transform="translate(544.820291, 73.354278) scale(-1, 1) translate(-544.820291, -73.354278) translate(498.000000, 40.000000)">
|
||||
<path d="M0.189217386,22.0262538 C4.67982195,48.130845 47.3500951,41.2287883 41.2211947,14.9008684 C35.725404,-8.70815982 -1.30323119,-2.17320732 0.189217386,22.0262538" id="Shape" stroke="#000000" stroke-width="2.9081" fill="#FFFFFF" transform="translate(20.983709, 19.323749) scale(-1, 1) translate(-20.983709, -19.323749) "></path>
|
||||
<path d="M52.375353,26.2483668 C58.1955709,48.9748532 94.5815666,43.1562926 93.210105,20.3593686 C91.566837,-6.94349061 46.8107821,-1.67517201 52.375353,26.2483668" id="Shape" stroke="#000000" stroke-width="2.8214" fill="#FFFFFF" transform="translate(72.572290, 21.519163) scale(-1, 1) translate(-72.572290, -21.519163) "></path>
|
||||
<path d="M44.873555,53.3511003 C44.8926146,56.7449499 45.6446396,60.55521 45.0028287,64.1657589 C44.1364461,65.803226 42.4368281,65.9764197 40.9717259,66.6381188 C38.9456089,66.3203209 37.2418475,64.9898788 36.429329,63.0946929 C35.9093337,58.9736786 36.6232396,54.9835954 36.754585,50.8609237 L44.873555,53.3511003 L44.873555,53.3511003 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(40.751798, 58.749521) scale(-1, 1) translate(-40.751798, -58.749521) "></path>
|
||||
<g id="Group" transform="translate(82.660603, 22.581468) scale(-1, 1) translate(-82.660603, -22.581468) translate(76.238350, 15.744877)">
|
||||
<ellipse id="Oval" fill="#000000" cx="6.34932869" cy="6.79681466" rx="6.14423095" ry="6.65511077"></ellipse>
|
||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.1440443" cy="8.29879302" rx="1.44852865" ry="1.69133123"></ellipse>
|
||||
</g>
|
||||
<g id="Group" transform="translate(31.282584, 20.924112) scale(-1, 1) translate(-31.282584, -20.924112) translate(24.860332, 14.087521)">
|
||||
<ellipse id="Oval" fill="#000000" cx="6.45167039" cy="6.79681466" rx="6.04188925" ry="6.65511077"></ellipse>
|
||||
<ellipse id="Oval" fill="#FFFFFF" cx="9.19998004" cy="8.29879302" rx="1.424497" ry="1.69133123"></ellipse>
|
||||
</g>
|
||||
<path d="M45.9112876,52.98151 C43.2305151,59.4783433 47.4062222,72.4699383 54.6799409,62.8875235 C54.1599456,58.7665092 54.8738514,54.776426 55.0051969,50.6537543 L45.9112876,52.98151 L45.9112876,52.98151 Z" id="Shape" stroke="#000000" stroke-width="3" fill="#FFFFFF" transform="translate(50.049883, 58.475140) scale(-1, 1) translate(-50.049883, -58.475140) "></path>
|
||||
<g id="Group" transform="translate(44.127089, 43.712750) scale(-1, 1) translate(-44.127089, -43.712750) translate(27.346365, 32.732770)">
|
||||
<path d="M7.63667953,7.73694953 C2.64016722,8.16288988 -1.44397093,14.1036805 1.15393372,18.8035261 C4.59418928,25.0285532 12.2731314,18.2528698 17.0558448,18.8876369 C22.5603366,19.0003371 27.0704151,24.7078549 31.4914107,19.9193407 C36.40837,14.593429 29.3745535,9.4063208 23.8771055,7.0872662 L7.63667953,7.73694953 L7.63667953,7.73694953 Z" id="Shape" stroke="#231F20" stroke-width="3" fill="#F6D2A2"></path>
|
||||
<path d="M7.00771314,7.47674473 C6.63770854,-1.1792084 23.1412397,-2.2614615 25.0902897,4.98408215 C27.0343676,12.2126379 7.82023164,13.891539 7.00771314,7.47674473 C6.35927282,2.3542734 7.00771314,7.47674473 7.00771314,7.47674473 L7.00771314,7.47674473 Z" id="Shape" fill="#000000"></path>
|
||||
</g>
|
||||
</g>
|
||||
<text id="QUIC" font-family="SourceCodePro-Light, Source Code Pro" font-size="288" font-weight="300" letter-spacing="-20" fill="#000000">
|
||||
<tspan x="-13.6" y="197">QUI</tspan>
|
||||
<tspan x="444.8" y="197">C</tspan>
|
||||
</text>
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 11 KiB |
|
@ -0,0 +1,9 @@
|
|||
FROM scratch
|
||||
|
||||
VOLUME /certs
|
||||
VOLUME /www
|
||||
EXPOSE 6121
|
||||
|
||||
ADD main /main
|
||||
|
||||
CMD ["/main", "-bind=0.0.0.0", "-certpath=/certs/", "-www=/www"]
|
|
@ -0,0 +1,7 @@
|
|||
# About the certificate
|
||||
|
||||
Yes, this folder contains a private key and a certificate for quic.clemente.io.
|
||||
|
||||
Unfortunately we need a valid certificate for the integration tests with Chrome and `quic_client`. No important data is served on the "real" `quic.clemente.io` (only a test page), and the MITM problem is imho negligible.
|
||||
|
||||
If you figure out a way to test with Chrome without having a cert and key here, let us now in an issue.
|
|
@ -0,0 +1,66 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
func main() {
|
||||
verbose := flag.Bool("v", false, "verbose")
|
||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
flag.Parse()
|
||||
urls := flag.Args()
|
||||
|
||||
logger := utils.DefaultLogger
|
||||
|
||||
if *verbose {
|
||||
logger.SetLogLevel(utils.LogLevelDebug)
|
||||
} else {
|
||||
logger.SetLogLevel(utils.LogLevelInfo)
|
||||
}
|
||||
logger.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||
}
|
||||
|
||||
roundTripper := &h2quic.RoundTripper{
|
||||
QuicConfig: &quic.Config{Versions: versions},
|
||||
}
|
||||
defer roundTripper.Close()
|
||||
hclient := &http.Client{
|
||||
Transport: roundTripper,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(urls))
|
||||
for _, addr := range urls {
|
||||
logger.Infof("GET %s", addr)
|
||||
go func(addr string) {
|
||||
rsp, err := hclient.Get(addr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
logger.Infof("Got response for %s: %#v", addr, rsp)
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
_, err = io.Copy(body, rsp.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
logger.Infof("Request Body:")
|
||||
logger.Infof("%s", body.Bytes())
|
||||
wg.Done()
|
||||
}(addr)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
const addr = "localhost:4242"
|
||||
|
||||
const message = "foobar"
|
||||
|
||||
// We start a server echoing data on the first stream the client opens,
|
||||
// then connect with a client, send the message, and wait for its receipt.
|
||||
func main() {
|
||||
go func() { log.Fatal(echoServer()) }()
|
||||
|
||||
err := clientMain()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start a server that echos all data on the first stream opened by the client
|
||||
func echoServer() error {
|
||||
listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sess, err := listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Echo through the loggingWriter
|
||||
_, err = io.Copy(loggingWriter{stream}, stream)
|
||||
return err
|
||||
}
|
||||
|
||||
func clientMain() error {
|
||||
session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := session.OpenStreamSync()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Client: Sending '%s'\n", message)
|
||||
_, err = stream.Write([]byte(message))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, len(message))
|
||||
_, err = io.ReadFull(stream, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Client: Got '%s'\n", buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// A wrapper for io.Writer that also logs the message.
|
||||
type loggingWriter struct{ io.Writer }
|
||||
|
||||
func (w loggingWriter) Write(b []byte) (int, error) {
|
||||
fmt.Printf("Server: Got '%s'\n", string(b))
|
||||
return w.Writer.Write(b)
|
||||
}
|
||||
|
||||
// Setup a bare-bones TLS config for the server
|
||||
func generateTLSConfig() *tls.Config {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
template := x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIGCzCCBPOgAwIBAgISA6pwet1vq9IjS9ThcggZYGfyMA0GCSqGSIb3DQEBCwUA
|
||||
MEoxCzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MSMwIQYDVQQD
|
||||
ExpMZXQncyBFbmNyeXB0IEF1dGhvcml0eSBYMzAeFw0xODA2MDkwODI2MjVaFw0x
|
||||
ODA5MDcwODI2MjVaMBsxGTAXBgNVBAMTEHF1aWMuY2xlbWVudGUuaW8wggEiMA0G
|
||||
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDqON9RKCMK4dHxLTLiv3n+b/v4KCbY
|
||||
Fo38rxCcv4/qoWI1Zoz4XMyPnTZPrzVMz6rhDKWTve1v7g+itgQDp1OoRo5ihtqC
|
||||
m9Dr1Ed+qb3hHtFRBBOTiWQy1Y0fvUDYGlSHQ4R7xXYVzCYkQ9zopqr2otJp/9ZA
|
||||
1Yy3ATITSgcds9wJaAOpkSxCx+D7cqpGbaNtxogRZ4ZT1vVs3l7Kr26fvnALMM4K
|
||||
nf2Rpq1ZjgPRoLNcfUWCZO7gr+VNCFaI2AQRaAzZkSXAk0B91waTtaFD87kYXl9Z
|
||||
8Jgx1PhJnO8WBUz/7+OAO6M+kuT78e6FltMN004Q8WmuNww6PzmVwnmpAgMBAAGj
|
||||
ggMYMIIDFDAOBgNVHQ8BAf8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG
|
||||
AQUFBwMCMAwGA1UdEwEB/wQCMAAwHQYDVR0OBBYEFIs3yJaLnEw7MRtNf9EbzESy
|
||||
TrHwMB8GA1UdIwQYMBaAFKhKamMEfd265tE5t6ZFZe/zqOyhMG8GCCsGAQUFBwEB
|
||||
BGMwYTAuBggrBgEFBQcwAYYiaHR0cDovL29jc3AuaW50LXgzLmxldHNlbmNyeXB0
|
||||
Lm9yZzAvBggrBgEFBQcwAoYjaHR0cDovL2NlcnQuaW50LXgzLmxldHNlbmNyeXB0
|
||||
Lm9yZy8wGwYDVR0RBBQwEoIQcXVpYy5jbGVtZW50ZS5pbzCB/gYDVR0gBIH2MIHz
|
||||
MAgGBmeBDAECATCB5gYLKwYBBAGC3xMBAQEwgdYwJgYIKwYBBQUHAgEWGmh0dHA6
|
||||
Ly9jcHMubGV0c2VuY3J5cHQub3JnMIGrBggrBgEFBQcCAjCBngyBm1RoaXMgQ2Vy
|
||||
dGlmaWNhdGUgbWF5IG9ubHkgYmUgcmVsaWVkIHVwb24gYnkgUmVseWluZyBQYXJ0
|
||||
aWVzIGFuZCBvbmx5IGluIGFjY29yZGFuY2Ugd2l0aCB0aGUgQ2VydGlmaWNhdGUg
|
||||
UG9saWN5IGZvdW5kIGF0IGh0dHBzOi8vbGV0c2VuY3J5cHQub3JnL3JlcG9zaXRv
|
||||
cnkvMIIBBAYKKwYBBAHWeQIEAgSB9QSB8gDwAHYA23Sv7ssp7LH+yj5xbSzluaq7
|
||||
NveEcYPHXZ1PN7Yfv2QAAAFj495INQAABAMARzBFAiEApF0BFCWyGIUrJrsYuugt
|
||||
tshGVdg2+7f6d4B1D/xF2s0CIGGsVL2nRlXJXrkk3aa83lH4HzP9vcQSnMFHdXOf
|
||||
9XeZAHYAKTxRllTIOWW6qlD8WAfUt2+/WHopctykwwz05UVH9HgAAAFj495IQwAA
|
||||
BAMARzBFAiEAkV4gbM6hucL7ZwqTzb5fKxYhk6WHr5y8pzClZD3qqv4CIBYH7MSA
|
||||
P05CXv7tHiHEizIvhWJJvVa2E6XLjbDRQMnUMA0GCSqGSIb3DQEBCwUAA4IBAQA2
|
||||
CALjtlxGXkkfKsRgKDoPpzg/IAl7crq5OrGGwW/bxbeDeiRHVt0Hlhr+0XPYlh/A
|
||||
m8qBlg7TMHJa2zIt7wkG/MX3d9bO2bxXxsdjfMXLtxQu7eZ5nzGuXL8WGnZwtSqz
|
||||
M5pOF/AU1JQojIhehKCeqqUi5UobxXUm+9D6OVmr8s732X3n/TL6pgsWRhay9tjB
|
||||
kdZ9TSe0tLYyXnbwHp0rgwNWOMMN1Tc+Fpqc8UlrCq5REb0bLIQ6A2IJH08MEPWG
|
||||
ukXHPAaHLz9oB1emR3flCoMKB0KrXUpDFXemOIPN6QVBO16LNNuRffKwXnzp60+b
|
||||
MoB4Krxrab7TzlfT+HnP
|
||||
-----END CERTIFICATE-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEkjCCA3qgAwIBAgIQCgFBQgAAAVOFc2oLheynCDANBgkqhkiG9w0BAQsFADA/
|
||||
MSQwIgYDVQQKExtEaWdpdGFsIFNpZ25hdHVyZSBUcnVzdCBDby4xFzAVBgNVBAMT
|
||||
DkRTVCBSb290IENBIFgzMB4XDTE2MDMxNzE2NDA0NloXDTIxMDMxNzE2NDA0Nlow
|
||||
SjELMAkGA1UEBhMCVVMxFjAUBgNVBAoTDUxldCdzIEVuY3J5cHQxIzAhBgNVBAMT
|
||||
GkxldCdzIEVuY3J5cHQgQXV0aG9yaXR5IFgzMIIBIjANBgkqhkiG9w0BAQEFAAOC
|
||||
AQ8AMIIBCgKCAQEAnNMM8FrlLke3cl03g7NoYzDq1zUmGSXhvb418XCSL7e4S0EF
|
||||
q6meNQhY7LEqxGiHC6PjdeTm86dicbp5gWAf15Gan/PQeGdxyGkOlZHP/uaZ6WA8
|
||||
SMx+yk13EiSdRxta67nsHjcAHJyse6cF6s5K671B5TaYucv9bTyWaN8jKkKQDIZ0
|
||||
Z8h/pZq4UmEUEz9l6YKHy9v6Dlb2honzhT+Xhq+w3Brvaw2VFn3EK6BlspkENnWA
|
||||
a6xK8xuQSXgvopZPKiAlKQTGdMDQMc2PMTiVFrqoM7hD8bEfwzB/onkxEz0tNvjj
|
||||
/PIzark5McWvxI0NHWQWM6r6hCm21AvA2H3DkwIDAQABo4IBfTCCAXkwEgYDVR0T
|
||||
AQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAYYwfwYIKwYBBQUHAQEEczBxMDIG
|
||||
CCsGAQUFBzABhiZodHRwOi8vaXNyZy50cnVzdGlkLm9jc3AuaWRlbnRydXN0LmNv
|
||||
bTA7BggrBgEFBQcwAoYvaHR0cDovL2FwcHMuaWRlbnRydXN0LmNvbS9yb290cy9k
|
||||
c3Ryb290Y2F4My5wN2MwHwYDVR0jBBgwFoAUxKexpHsscfrb4UuQdf/EFWCFiRAw
|
||||
VAYDVR0gBE0wSzAIBgZngQwBAgEwPwYLKwYBBAGC3xMBAQEwMDAuBggrBgEFBQcC
|
||||
ARYiaHR0cDovL2Nwcy5yb290LXgxLmxldHNlbmNyeXB0Lm9yZzA8BgNVHR8ENTAz
|
||||
MDGgL6AthitodHRwOi8vY3JsLmlkZW50cnVzdC5jb20vRFNUUk9PVENBWDNDUkwu
|
||||
Y3JsMB0GA1UdDgQWBBSoSmpjBH3duubRObemRWXv86jsoTANBgkqhkiG9w0BAQsF
|
||||
AAOCAQEA3TPXEfNjWDjdGBX7CVW+dla5cEilaUcne8IkCJLxWh9KEik3JHRRHGJo
|
||||
uM2VcGfl96S8TihRzZvoroed6ti6WqEBmtzw3Wodatg+VyOeph4EYpr/1wXKtx8/
|
||||
wApIvJSwtmVi4MFU5aMqrSDE6ea73Mj2tcMyo5jMd6jmeWUHK8so/joWUoHOUgwu
|
||||
X4Po1QYz+3dszkDqMp4fklxBwXRsW10KXzPMTZ+sOPAveyxindmjkW8lGy+QsRlG
|
||||
PfZ+G6Z6h7mjem0Y+iWlkYcV4PIWL1iwBi8saCbGS5jN2p8M+X+Q7UNKEkROb3N6
|
||||
KOqkqm57TH2H3eDJAkSnh6/DNFu0Qg==
|
||||
-----END CERTIFICATE-----
|
|
@ -0,0 +1,174 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "net/http/pprof"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type binds []string
|
||||
|
||||
func (b binds) String() string {
|
||||
return strings.Join(b, ",")
|
||||
}
|
||||
|
||||
func (b *binds) Set(v string) error {
|
||||
*b = strings.Split(v, ",")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size is needed by the /demo/upload handler to determine the size of the uploaded file
|
||||
type Size interface {
|
||||
Size() int64
|
||||
}
|
||||
|
||||
func init() {
|
||||
http.HandleFunc("/demo/tile", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Small 40x40 png
|
||||
w.Write([]byte{
|
||||
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d,
|
||||
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x28,
|
||||
0x01, 0x03, 0x00, 0x00, 0x00, 0xb6, 0x30, 0x2a, 0x2e, 0x00, 0x00, 0x00,
|
||||
0x03, 0x50, 0x4c, 0x54, 0x45, 0x5a, 0xc3, 0x5a, 0xad, 0x38, 0xaa, 0xdb,
|
||||
0x00, 0x00, 0x00, 0x0b, 0x49, 0x44, 0x41, 0x54, 0x78, 0x01, 0x63, 0x18,
|
||||
0x61, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x01, 0xe2, 0xb8, 0x75, 0x22, 0x00,
|
||||
0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82,
|
||||
})
|
||||
})
|
||||
|
||||
http.HandleFunc("/demo/tiles", func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, "<html><head><style>img{width:40px;height:40px;}</style></head><body>")
|
||||
for i := 0; i < 200; i++ {
|
||||
fmt.Fprintf(w, `<img src="/demo/tile?cachebust=%d">`, i)
|
||||
}
|
||||
io.WriteString(w, "</body></html>")
|
||||
})
|
||||
|
||||
http.HandleFunc("/demo/echo", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("error reading body while handling /echo: %s\n", err.Error())
|
||||
}
|
||||
w.Write(body)
|
||||
})
|
||||
|
||||
// accept file uploads and return the MD5 of the uploaded file
|
||||
// maximum accepted file size is 1 GB
|
||||
http.HandleFunc("/demo/upload", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost {
|
||||
err := r.ParseMultipartForm(1 << 30) // 1 GB
|
||||
if err == nil {
|
||||
var file multipart.File
|
||||
file, _, err = r.FormFile("uploadfile")
|
||||
if err == nil {
|
||||
var size int64
|
||||
if sizeInterface, ok := file.(Size); ok {
|
||||
size = sizeInterface.Size()
|
||||
b := make([]byte, size)
|
||||
file.Read(b)
|
||||
md5 := md5.Sum(b)
|
||||
fmt.Fprintf(w, "%x", md5)
|
||||
return
|
||||
}
|
||||
err = errors.New("couldn't get uploaded file size")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
utils.DefaultLogger.Infof("Error receiving upload: %#v", err)
|
||||
}
|
||||
}
|
||||
io.WriteString(w, `<html><body><form action="/demo/upload" method="post" enctype="multipart/form-data">
|
||||
<input type="file" name="uploadfile"><br>
|
||||
<input type="submit">
|
||||
</form></body></html>`)
|
||||
})
|
||||
}
|
||||
|
||||
func getBuildDir() string {
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
panic("Failed to get current frame")
|
||||
}
|
||||
|
||||
return path.Dir(filename)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// defer profile.Start().Stop()
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
// runtime.SetBlockProfileRate(1)
|
||||
|
||||
verbose := flag.Bool("v", false, "verbose")
|
||||
bs := binds{}
|
||||
flag.Var(&bs, "bind", "bind to")
|
||||
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
||||
www := flag.String("www", "/var/www", "www data")
|
||||
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
flag.Parse()
|
||||
|
||||
logger := utils.DefaultLogger
|
||||
|
||||
if *verbose {
|
||||
logger.SetLogLevel(utils.LogLevelDebug)
|
||||
} else {
|
||||
logger.SetLogLevel(utils.LogLevelInfo)
|
||||
}
|
||||
logger.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||
}
|
||||
|
||||
certFile := *certPath + "/fullchain.pem"
|
||||
keyFile := *certPath + "/privkey.pem"
|
||||
|
||||
http.Handle("/", http.FileServer(http.Dir(*www)))
|
||||
|
||||
if len(bs) == 0 {
|
||||
bs = binds{"localhost:6121"}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(bs))
|
||||
for _, b := range bs {
|
||||
bCap := b
|
||||
go func() {
|
||||
var err error
|
||||
if *tcp {
|
||||
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
||||
} else {
|
||||
server := h2quic.Server{
|
||||
Server: &http.Server{Addr: bCap},
|
||||
QuicConfig: &quic.Config{Versions: versions},
|
||||
}
|
||||
err = server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDqON9RKCMK4dHx
|
||||
LTLiv3n+b/v4KCbYFo38rxCcv4/qoWI1Zoz4XMyPnTZPrzVMz6rhDKWTve1v7g+i
|
||||
tgQDp1OoRo5ihtqCm9Dr1Ed+qb3hHtFRBBOTiWQy1Y0fvUDYGlSHQ4R7xXYVzCYk
|
||||
Q9zopqr2otJp/9ZA1Yy3ATITSgcds9wJaAOpkSxCx+D7cqpGbaNtxogRZ4ZT1vVs
|
||||
3l7Kr26fvnALMM4Knf2Rpq1ZjgPRoLNcfUWCZO7gr+VNCFaI2AQRaAzZkSXAk0B9
|
||||
1waTtaFD87kYXl9Z8Jgx1PhJnO8WBUz/7+OAO6M+kuT78e6FltMN004Q8WmuNww6
|
||||
PzmVwnmpAgMBAAECggEAAW7hpux48msZTsF5CzwisfTbdNRCEJZqvf4QOvVNGyFr
|
||||
qWn8ONTQh5xtpaUrzVGD+SaLqNDDsCijvdohQih28ZOk8WNj2OK9L4Q3/8VoHQWE
|
||||
QFunBwMTMuBtoaEV0XyvwbgfCmbV5yI9pYEoy9+hMisi4HUpSXJFDyWZudZ9Hqhl
|
||||
FBf5UPxF1AjZViODVfnKkrWh85jaENyWjWrMEDSohzxP8IStFMK8E0feXaQb+G0f
|
||||
uTelpG2JmVcO3xaYgtRVeS4p5liMQg5gE5oa2Jh7Vp3TwMhQVg0aLmslSLAYlPoh
|
||||
hyBeOS3ucyFHoC/6Stnnx3jdOEf2lEUObJj3QVEeBQKBgQD3qptNY9R62UQFO8gT
|
||||
pseEO5CAZRGuKG0VLPNqKKerYQreiT3xYOTPmwEy+xXzYV6DKHtlKDetHuOB862M
|
||||
E1bKmjDedafQ3KMc4tywLW+U9I8GyooT8uoetzARzgoftRcMzdeu4fexWTsi+kOh
|
||||
5/PeXUBnFph8E68nXHQR5+MWAwKBgQDyGnT9KUVuLvzrFAghe9Eu1iZJSDs/IrJj
|
||||
we3XQ7loqMc/Qv34eVKATsgtb2cTSeTivQcSvQO+Uu9dgo17BoWAABDTaV8NAOZ6
|
||||
cV2kedrWnxaTRXzB6Z5EKLg+DOMTpCV4Nf3OwzGq/mnKAe8cNM/hS+8HcZywKwr2
|
||||
UwLKSq6n4wKBgQDXGUaOnVCSbZZlETnAz43i67ShvqXvY07yIDs8jRiqgLrm8b1p
|
||||
oaS4JkCRXX8ABSYHtaYOAjLw2a3wVIn66WTsy6P74aWhga7szJ+tJ5kMfqal2Ey5
|
||||
7LSnfqRyIkeqqCXfyfsz+S+dyQjSZRdOS90C2Gyx2+8NfC8YeXSZhJM2rwKBgQDB
|
||||
pL28G/2nsrejQ2N5fLKE9s6qwLZ6ukLbHasiCc5L0uuDQw8mZcvCSsE77iYQvILx
|
||||
hGYa68oJugYw0hJdu4qeJe9PWbGoEfdHKlPPEZQjJB4Hb4XpB/YJ6FPtdZtPA3Tg
|
||||
4LaAYYnhjhqJc+CPvAIl3vlyB8Je+h6LhTvvF6r5JwKBgHBkpQ2Y+N3dMgfE8T38
|
||||
+Vy+75i9ew3F0TB8o2MWJzYLYCzNLqrB4GkZieudoYP1Cmh50wYIifsk2zNpxOG0
|
||||
27tUfMxza4ITUfmtVHS22DGWCqb1TgfQt7jajdioK4TUOZFEmiQiOYMZoFMenbzZ
|
||||
miz6s3fWx1/b3wNd1J5uV2QS
|
||||
-----END PRIVATE KEY-----
|
|
@ -0,0 +1,158 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type frameSorter struct {
|
||||
queue map[protocol.ByteCount][]byte
|
||||
readPos protocol.ByteCount
|
||||
finalOffset protocol.ByteCount
|
||||
gaps *utils.ByteIntervalList
|
||||
}
|
||||
|
||||
var errDuplicateStreamData = errors.New("Duplicate Stream Data")
|
||||
|
||||
func newFrameSorter() *frameSorter {
|
||||
s := frameSorter{
|
||||
gaps: utils.NewByteIntervalList(),
|
||||
queue: make(map[protocol.ByteCount][]byte),
|
||||
finalOffset: protocol.MaxByteCount,
|
||||
}
|
||||
s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount})
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, fin bool) error {
|
||||
err := s.push(data, offset, fin)
|
||||
if err == errDuplicateStreamData {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, fin bool) error {
|
||||
if fin {
|
||||
s.finalOffset = offset + protocol.ByteCount(len(data))
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var wasCut bool
|
||||
if oldData, ok := s.queue[offset]; ok {
|
||||
if len(data) <= len(oldData) {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
data = data[len(oldData):]
|
||||
offset += protocol.ByteCount(len(oldData))
|
||||
wasCut = true
|
||||
}
|
||||
|
||||
start := offset
|
||||
end := offset + protocol.ByteCount(len(data))
|
||||
|
||||
// skip all gaps that are before this stream frame
|
||||
var gap *utils.ByteIntervalElement
|
||||
for gap = s.gaps.Front(); gap != nil; gap = gap.Next() {
|
||||
// the frame is a duplicate. Ignore it
|
||||
if end <= gap.Value.Start {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
if end > gap.Value.Start && start <= gap.Value.End {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if gap == nil {
|
||||
return errors.New("StreamFrameSorter BUG: no gap found")
|
||||
}
|
||||
|
||||
if start < gap.Value.Start {
|
||||
add := gap.Value.Start - start
|
||||
offset += add
|
||||
start += add
|
||||
data = data[add:]
|
||||
wasCut = true
|
||||
}
|
||||
|
||||
// find the highest gaps whose Start lies before the end of the frame
|
||||
endGap := gap
|
||||
for end >= endGap.Value.End {
|
||||
nextEndGap := endGap.Next()
|
||||
if nextEndGap == nil {
|
||||
return errors.New("StreamFrameSorter BUG: no end gap found")
|
||||
}
|
||||
if endGap != gap {
|
||||
s.gaps.Remove(endGap)
|
||||
}
|
||||
if end <= nextEndGap.Value.Start {
|
||||
break
|
||||
}
|
||||
// delete queued frames completely covered by the current frame
|
||||
delete(s.queue, endGap.Value.End)
|
||||
endGap = nextEndGap
|
||||
}
|
||||
|
||||
if end > endGap.Value.End {
|
||||
cutLen := end - endGap.Value.End
|
||||
len := protocol.ByteCount(len(data)) - cutLen
|
||||
end -= cutLen
|
||||
data = data[:len]
|
||||
wasCut = true
|
||||
}
|
||||
|
||||
if start == gap.Value.Start {
|
||||
if end >= gap.Value.End {
|
||||
// the frame completely fills this gap
|
||||
// delete the gap
|
||||
s.gaps.Remove(gap)
|
||||
}
|
||||
if end < endGap.Value.End {
|
||||
// the frame covers the beginning of the gap
|
||||
// adjust the Start value to shrink the gap
|
||||
endGap.Value.Start = end
|
||||
}
|
||||
} else if end == endGap.Value.End {
|
||||
// the frame covers the end of the gap
|
||||
// adjust the End value to shrink the gap
|
||||
gap.Value.End = start
|
||||
} else {
|
||||
if gap == endGap {
|
||||
// the frame lies within the current gap, splitting it into two
|
||||
// insert a new gap and adjust the current one
|
||||
intv := utils.ByteInterval{Start: end, End: gap.Value.End}
|
||||
s.gaps.InsertAfter(intv, gap)
|
||||
gap.Value.End = start
|
||||
} else {
|
||||
gap.Value.End = start
|
||||
endGap.Value.Start = end
|
||||
}
|
||||
}
|
||||
|
||||
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
|
||||
return errors.New("Too many gaps in received data")
|
||||
}
|
||||
|
||||
if wasCut {
|
||||
newData := make([]byte, len(data))
|
||||
copy(newData, data)
|
||||
data = newData
|
||||
}
|
||||
|
||||
s.queue[offset] = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
|
||||
data, ok := s.queue[s.readPos]
|
||||
if !ok {
|
||||
return nil, s.readPos >= s.finalOffset
|
||||
}
|
||||
delete(s.queue, s.readPos)
|
||||
s.readPos += protocol.ByteCount(len(data))
|
||||
return data, s.readPos >= s.finalOffset
|
||||
}
|
|
@ -0,0 +1,435 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("STREAM frame sorter", func() {
|
||||
var s *frameSorter
|
||||
|
||||
checkGaps := func(expectedGaps []utils.ByteInterval) {
|
||||
Expect(s.gaps.Len()).To(Equal(len(expectedGaps)))
|
||||
var i int
|
||||
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
|
||||
Expect(gap.Value).To(Equal(expectedGaps[i]))
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
s = newFrameSorter()
|
||||
})
|
||||
|
||||
It("head returns nil when empty", func() {
|
||||
Expect(s.Pop()).To(BeNil())
|
||||
})
|
||||
|
||||
Context("Push", func() {
|
||||
It("inserts and pops a single frame", func() {
|
||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
||||
data, fin := s.Pop()
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
Expect(fin).To(BeFalse())
|
||||
Expect(s.Pop()).To(BeNil())
|
||||
})
|
||||
|
||||
It("inserts and pops two consecutive frame", func() {
|
||||
Expect(s.Push([]byte("foo"), 0, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("bar"), 3, false)).To(Succeed())
|
||||
data, fin := s.Pop()
|
||||
Expect(data).To(Equal([]byte("foo")))
|
||||
Expect(fin).To(BeFalse())
|
||||
data, fin = s.Pop()
|
||||
Expect(data).To(Equal([]byte("bar")))
|
||||
Expect(fin).To(BeFalse())
|
||||
Expect(s.Pop()).To(BeNil())
|
||||
})
|
||||
|
||||
It("ignores empty frames", func() {
|
||||
Expect(s.Push(nil, 0, false)).To(Succeed())
|
||||
Expect(s.Pop()).To(BeNil())
|
||||
})
|
||||
|
||||
Context("FIN handling", func() {
|
||||
It("saves a FIN at offset 0", func() {
|
||||
Expect(s.Push(nil, 0, true)).To(Succeed())
|
||||
data, fin := s.Pop()
|
||||
Expect(data).To(BeEmpty())
|
||||
Expect(fin).To(BeTrue())
|
||||
data, fin = s.Pop()
|
||||
Expect(data).To(BeNil())
|
||||
Expect(fin).To(BeTrue())
|
||||
})
|
||||
|
||||
It("saves a FIN frame at non-zero offset", func() {
|
||||
Expect(s.Push([]byte("foobar"), 0, true)).To(Succeed())
|
||||
data, fin := s.Pop()
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
Expect(fin).To(BeTrue())
|
||||
data, fin = s.Pop()
|
||||
Expect(data).To(BeNil())
|
||||
Expect(fin).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the FIN if a stream is closed after receiving some data", func() {
|
||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
||||
Expect(s.Push(nil, 6, true)).To(Succeed())
|
||||
data, fin := s.Pop()
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
Expect(fin).To(BeTrue())
|
||||
data, fin = s.Pop()
|
||||
Expect(data).To(BeNil())
|
||||
Expect(fin).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Gap handling", func() {
|
||||
It("finds the first gap", func() {
|
||||
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 10},
|
||||
{Start: 16, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("correctly sets the first gap for a frame with offset 0", func() {
|
||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 6, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("finds the two gaps", func() {
|
||||
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("foobar"), 20, false)).To(Succeed())
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 10},
|
||||
{Start: 16, End: 20},
|
||||
{Start: 26, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("finds the two gaps in reverse order", func() {
|
||||
Expect(s.Push([]byte("foobar"), 20, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 10},
|
||||
{Start: 16, End: 20},
|
||||
{Start: 26, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("shrinks a gap when it is partially filled", func() {
|
||||
Expect(s.Push([]byte("test"), 10, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("foobar"), 4, false)).To(Succeed())
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 4},
|
||||
{Start: 14, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("deletes a gap at the beginning, when it is filled", func() {
|
||||
Expect(s.Push([]byte("test"), 6, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 10, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("deletes a gap in the middle, when it is filled", func() {
|
||||
Expect(s.Push([]byte("test"), 0, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("test2"), 10, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("foobar"), 4, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveLen(3))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 15, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("splits a gap into two", func() {
|
||||
Expect(s.Push([]byte("test"), 100, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("foobar"), 50, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveLen(2))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 50},
|
||||
{Start: 56, End: 100},
|
||||
{Start: 104, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
Context("Overlapping Stream Data detection", func() {
|
||||
// create gaps: 0-5, 10-15, 20-25, 30-inf
|
||||
BeforeEach(func() {
|
||||
Expect(s.Push([]byte("12345"), 5, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("12345"), 15, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("12345"), 25, false)).To(Succeed())
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 5},
|
||||
{Start: 10, End: 15},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("cuts a frame with offset 0 that overlaps at the end", func() {
|
||||
Expect(s.Push([]byte("foobar"), 0, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(0)))
|
||||
Expect(s.queue[0]).To(Equal([]byte("fooba")))
|
||||
Expect(s.queue[0]).To(HaveCap(5))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 10, End: 15},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("cuts a frame that overlaps at the end", func() {
|
||||
// 4 to 7
|
||||
Expect(s.Push([]byte("foo"), 4, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(4)))
|
||||
Expect(s.queue[4]).To(Equal([]byte("f")))
|
||||
Expect(s.queue[4]).To(HaveCap(1))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 4},
|
||||
{Start: 10, End: 15},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("cuts a frame that completely fills a gap, but overlaps at the end", func() {
|
||||
// 10 to 16
|
||||
Expect(s.Push([]byte("foobar"), 10, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
||||
Expect(s.queue[10]).To(Equal([]byte("fooba")))
|
||||
Expect(s.queue[10]).To(HaveCap(5))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 5},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("cuts a frame that overlaps at the beginning", func() {
|
||||
// 8 to 14
|
||||
Expect(s.Push([]byte("foobar"), 8, false)).To(Succeed())
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
||||
Expect(s.queue[10]).To(Equal([]byte("obar")))
|
||||
Expect(s.queue[10]).To(HaveCap(4))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 5},
|
||||
{Start: 14, End: 15},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("processes a frame that overlaps at the beginning and at the end, starting in a gap", func() {
|
||||
// 2 to 12
|
||||
Expect(s.Push([]byte("1234567890"), 2, false)).To(Succeed())
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
|
||||
Expect(s.queue[2]).To(Equal([]byte("1234567890")))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 2},
|
||||
{Start: 12, End: 15},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() {
|
||||
// 2 to 17
|
||||
Expect(s.Push([]byte("123456789012345"), 2, false)).To(Succeed())
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
|
||||
Expect(s.queue[2]).To(Equal([]byte("1234567890123")))
|
||||
Expect(s.queue[2]).To(HaveCap(13))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 2},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("processes a frame that overlaps at the beginning and at the end, starting in a gap, ending in data", func() {
|
||||
// 5 to 22
|
||||
Expect(s.Push([]byte("12345678901234567"), 5, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(5)))
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
|
||||
Expect(s.queue[10]).To(Equal([]byte("678901234567")))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 5},
|
||||
{Start: 22, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("processes a frame that closes multiple gaps", func() {
|
||||
// 2 to 27
|
||||
Expect(s.Push(bytes.Repeat([]byte{'e'}, 25), 2, false)).To(Succeed())
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(25)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(2)))
|
||||
Expect(s.queue[2]).To(Equal(bytes.Repeat([]byte{'e'}, 23)))
|
||||
Expect(s.queue[2]).To(HaveCap(23))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 2},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("processes a frame that closes multiple gaps", func() {
|
||||
// 5 to 27
|
||||
Expect(s.Push(bytes.Repeat([]byte{'d'}, 22), 5, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(5)))
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(15)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(25)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
||||
Expect(s.queue[10]).To(Equal(bytes.Repeat([]byte{'d'}, 15)))
|
||||
Expect(s.queue[10]).To(HaveCap(15))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 5},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("processes a frame that covers multiple gaps and ends at the end of a gap", func() {
|
||||
data := bytes.Repeat([]byte{'e'}, 14)
|
||||
// 1 to 15
|
||||
Expect(s.Push(data, 1, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(1)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(15)))
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(5)))
|
||||
Expect(s.queue[1]).To(Equal(data))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 1},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("processes a frame that closes all gaps (except for the last one)", func() {
|
||||
data := bytes.Repeat([]byte{'f'}, 32)
|
||||
// 0 to 32
|
||||
Expect(s.Push(data, 0, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveLen(1))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(0)))
|
||||
Expect(s.queue[0]).To(Equal(data))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 32, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("cuts a frame that overlaps at the beginning and at the end, starting in data already received", func() {
|
||||
// 8 to 17
|
||||
Expect(s.Push([]byte("123456789"), 8, false)).To(Succeed())
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(8)))
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
||||
Expect(s.queue[10]).To(Equal([]byte("34567")))
|
||||
Expect(s.queue[10]).To(HaveCap(5))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 5},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
|
||||
It("cuts a frame that completely covers two gaps", func() {
|
||||
// 10 to 20
|
||||
Expect(s.Push([]byte("1234567890"), 10, false)).To(Succeed())
|
||||
Expect(s.queue).To(HaveKey(protocol.ByteCount(10)))
|
||||
Expect(s.queue[10]).To(Equal([]byte("12345")))
|
||||
Expect(s.queue[10]).To(HaveCap(5))
|
||||
checkGaps([]utils.ByteInterval{
|
||||
{Start: 0, End: 5},
|
||||
{Start: 20, End: 25},
|
||||
{Start: 30, End: protocol.MaxByteCount},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("duplicate data", func() {
|
||||
expectedGaps := []utils.ByteInterval{
|
||||
{Start: 5, End: 10},
|
||||
{Start: 15, End: protocol.MaxByteCount},
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
// create gaps: 5-10, 15-inf
|
||||
Expect(s.Push([]byte("12345"), 0, false)).To(Succeed())
|
||||
Expect(s.Push([]byte("12345"), 10, false)).To(Succeed())
|
||||
checkGaps(expectedGaps)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// check that the gaps were not modified
|
||||
checkGaps(expectedGaps)
|
||||
})
|
||||
|
||||
It("does not modify data when receiving a duplicate", func() {
|
||||
err := s.push([]byte("fffff"), 0, false)
|
||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
||||
Expect(s.queue[0]).ToNot(Equal([]byte("fffff")))
|
||||
})
|
||||
|
||||
It("detects a duplicate frame that is smaller than the original, starting at the beginning", func() {
|
||||
// 10 to 12
|
||||
err := s.push([]byte("12"), 10, false)
|
||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
||||
Expect(s.queue[10]).To(HaveLen(5))
|
||||
})
|
||||
|
||||
It("detects a duplicate frame that is smaller than the original, somewhere in the middle", func() {
|
||||
// 1 to 4
|
||||
err := s.push([]byte("123"), 1, false)
|
||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
||||
Expect(s.queue[0]).To(HaveLen(5))
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(1)))
|
||||
})
|
||||
|
||||
It("detects a duplicate frame that is smaller than the original, somewhere in the middle in the last block", func() {
|
||||
// 11 to 14
|
||||
err := s.push([]byte("123"), 11, false)
|
||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
||||
Expect(s.queue[10]).To(HaveLen(5))
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11)))
|
||||
})
|
||||
|
||||
It("detects a duplicate frame that is smaller than the original, with aligned end in the last block", func() {
|
||||
// 11 to 15
|
||||
err := s.push([]byte("1234"), 1, false)
|
||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
||||
Expect(s.queue[10]).To(HaveLen(5))
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(11)))
|
||||
})
|
||||
|
||||
It("detects a duplicate frame that is smaller than the original, with aligned end", func() {
|
||||
// 3 to 5
|
||||
err := s.push([]byte("12"), 3, false)
|
||||
Expect(err).To(MatchError(errDuplicateStreamData))
|
||||
Expect(s.queue[0]).To(HaveLen(5))
|
||||
Expect(s.queue).ToNot(HaveKey(protocol.ByteCount(3)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("DoS protection", func() {
|
||||
It("errors when too many gaps are created", func() {
|
||||
for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ {
|
||||
Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7), false)).To(Succeed())
|
||||
}
|
||||
Expect(s.gaps.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps))
|
||||
err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, false)
|
||||
Expect(err).To(MatchError("Too many gaps in received data"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,314 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/idna"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type roundTripperOpts struct {
|
||||
DisableCompression bool
|
||||
}
|
||||
|
||||
var dialAddr = quic.DialAddr
|
||||
|
||||
// client is a HTTP2 client doing QUIC requests
|
||||
type client struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *quic.Config
|
||||
opts *roundTripperOpts
|
||||
|
||||
hostname string
|
||||
handshakeErr error
|
||||
dialOnce sync.Once
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
session quic.Session
|
||||
headerStream quic.Stream
|
||||
headerErr *qerr.QuicError
|
||||
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
||||
requestWriter *requestWriter
|
||||
|
||||
responses map[protocol.StreamID]chan *http.Response
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = &client{}
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
RequestConnectionIDOmission: true,
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
// newClient creates a new client
|
||||
func newClient(
|
||||
hostname string,
|
||||
tlsConfig *tls.Config,
|
||||
opts *roundTripperOpts,
|
||||
quicConfig *quic.Config,
|
||||
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
|
||||
) *client {
|
||||
config := defaultQuicConfig
|
||||
if quicConfig != nil {
|
||||
config = quicConfig
|
||||
}
|
||||
return &client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
responses: make(map[protocol.StreamID]chan *http.Response),
|
||||
tlsConf: tlsConfig,
|
||||
config: config,
|
||||
opts: opts,
|
||||
headerErrored: make(chan struct{}),
|
||||
dialer: dialer,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
}
|
||||
|
||||
// dial dials the connection
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
if c.dialer != nil {
|
||||
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
||||
} else {
|
||||
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// once the version has been negotiated, open the header stream
|
||||
c.headerStream, err = c.session.OpenStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
|
||||
go c.handleHeaderStream()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleHeaderStream() {
|
||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
h2framer := http2.NewFramer(nil, c.headerStream)
|
||||
|
||||
var err error
|
||||
for err == nil {
|
||||
err = c.readResponse(h2framer, decoder)
|
||||
}
|
||||
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
|
||||
c.logger.Debugf("Error handling header stream: %s", err)
|
||||
}
|
||||
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
|
||||
// stop all running request
|
||||
close(c.headerErrored)
|
||||
}
|
||||
|
||||
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
|
||||
frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hframe, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok {
|
||||
return errors.New("not a headers frame")
|
||||
}
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read header fields: %s", err.Error())
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
|
||||
c.mutex.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
|
||||
}
|
||||
|
||||
rsp, err := responseFromHeaders(mhframe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
responseChan <- rsp
|
||||
return nil
|
||||
}
|
||||
|
||||
// Roundtrip executes a request and returns a response
|
||||
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// TODO: add port to address, if it doesn't have one
|
||||
if req.URL.Scheme != "https" {
|
||||
return nil, errors.New("quic http2: unsupported scheme")
|
||||
}
|
||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
||||
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||
}
|
||||
|
||||
c.dialOnce.Do(func() {
|
||||
c.handshakeErr = c.dial()
|
||||
})
|
||||
|
||||
if c.handshakeErr != nil {
|
||||
return nil, c.handshakeErr
|
||||
}
|
||||
|
||||
hasBody := (req.Body != nil)
|
||||
|
||||
responseChan := make(chan *http.Response)
|
||||
dataStream, err := c.session.OpenStreamSync()
|
||||
if err != nil {
|
||||
_ = c.closeWithError(err)
|
||||
return nil, err
|
||||
}
|
||||
c.mutex.Lock()
|
||||
c.responses[dataStream.StreamID()] = responseChan
|
||||
c.mutex.Unlock()
|
||||
|
||||
var requestedGzip bool
|
||||
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
|
||||
requestedGzip = true
|
||||
}
|
||||
// TODO: add support for trailers
|
||||
endStream := !hasBody
|
||||
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
|
||||
if err != nil {
|
||||
_ = c.closeWithError(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resc := make(chan error, 1)
|
||||
if hasBody {
|
||||
go func() {
|
||||
resc <- c.writeRequestBody(dataStream, req.Body)
|
||||
}()
|
||||
}
|
||||
|
||||
var res *http.Response
|
||||
|
||||
var receivedResponse bool
|
||||
var bodySent bool
|
||||
|
||||
if !hasBody {
|
||||
bodySent = true
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
for !(bodySent && receivedResponse) {
|
||||
select {
|
||||
case res = <-responseChan:
|
||||
receivedResponse = true
|
||||
c.mutex.Lock()
|
||||
delete(c.responses, dataStream.StreamID())
|
||||
c.mutex.Unlock()
|
||||
case err := <-resc:
|
||||
bodySent = true
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// error code 6 signals that stream was canceled
|
||||
dataStream.CancelRead(6)
|
||||
dataStream.CancelWrite(6)
|
||||
c.mutex.Lock()
|
||||
delete(c.responses, dataStream.StreamID())
|
||||
c.mutex.Unlock()
|
||||
return nil, ctx.Err()
|
||||
case <-c.headerErrored:
|
||||
// an error occurred on the header stream
|
||||
_ = c.closeWithError(c.headerErr)
|
||||
return nil, c.headerErr
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: correctly set this variable
|
||||
var streamEnded bool
|
||||
isHead := (req.Method == "HEAD")
|
||||
|
||||
res = setLength(res, isHead, streamEnded)
|
||||
|
||||
if streamEnded || isHead {
|
||||
res.Body = noBody
|
||||
} else {
|
||||
res.Body = dataStream
|
||||
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
||||
res.Header.Del("Content-Encoding")
|
||||
res.Header.Del("Content-Length")
|
||||
res.ContentLength = -1
|
||||
res.Body = &gzipReader{body: res.Body}
|
||||
res.Uncompressed = true
|
||||
}
|
||||
}
|
||||
|
||||
res.Request = req
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
|
||||
defer func() {
|
||||
cerr := body.Close()
|
||||
if err == nil {
|
||||
// TODO: what to do with dataStream here? Maybe reset it?
|
||||
err = cerr
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = io.Copy(dataStream, body)
|
||||
if err != nil {
|
||||
// TODO: what to do with dataStream here? Maybe reset it?
|
||||
return err
|
||||
}
|
||||
return dataStream.Close()
|
||||
}
|
||||
|
||||
func (c *client) closeWithError(e error) error {
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
|
||||
}
|
||||
|
||||
// Close closes the client
|
||||
func (c *client) Close() error {
|
||||
if c.session == nil {
|
||||
return nil
|
||||
}
|
||||
return c.session.Close()
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||
// and returns a host:port. The port 443 is added if needed.
|
||||
func authorityAddr(scheme string, authority string) (addr string) {
|
||||
host, port, err := net.SplitHostPort(authority)
|
||||
if err != nil { // authority didn't have a port
|
||||
port = "443"
|
||||
if scheme == "http" {
|
||||
port = "80"
|
||||
}
|
||||
host = authority
|
||||
}
|
||||
if a, err := idna.ToASCII(host); err == nil {
|
||||
host = a
|
||||
}
|
||||
// IPv6 address literal, without a port:
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
return host + ":" + port
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
|
@ -0,0 +1,639 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
client *client
|
||||
session *mockSession
|
||||
headerStream *mockStream
|
||||
req *http.Request
|
||||
origDialAddr = dialAddr
|
||||
)
|
||||
|
||||
injectResponse := func(id protocol.StreamID, rsp *http.Response) {
|
||||
EventuallyWithOffset(0, func() bool {
|
||||
client.mutex.Lock()
|
||||
defer client.mutex.Unlock()
|
||||
_, ok := client.responses[id]
|
||||
return ok
|
||||
}).Should(BeTrue())
|
||||
rspChan := client.responses[5]
|
||||
ExpectWithOffset(0, rspChan).ToNot(BeClosed())
|
||||
rspChan <- rsp
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
origDialAddr = dialAddr
|
||||
hostname := "quic.clemente.io:1337"
|
||||
client = newClient(hostname, nil, &roundTripperOpts{}, nil, nil)
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
session = newMockSession()
|
||||
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
||||
client.session = session
|
||||
|
||||
headerStream = newMockStream(3)
|
||||
client.headerStream = headerStream
|
||||
client.requestWriter = newRequestWriter(headerStream, utils.DefaultLogger)
|
||||
var err error
|
||||
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
dialAddr = origDialAddr
|
||||
})
|
||||
|
||||
It("saves the TLS config", func() {
|
||||
tlsConf := &tls.Config{InsecureSkipVerify: true}
|
||||
client = newClient("", tlsConf, &roundTripperOpts{}, nil, nil)
|
||||
Expect(client.tlsConf).To(Equal(tlsConf))
|
||||
})
|
||||
|
||||
It("saves the QUIC config", func() {
|
||||
quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf, nil)
|
||||
Expect(client.config).To(Equal(quicConf))
|
||||
})
|
||||
|
||||
It("uses the default QUIC config if none is give", func() {
|
||||
client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil, nil)
|
||||
Expect(client.config).ToNot(BeNil())
|
||||
Expect(client.config).To(Equal(defaultQuicConfig))
|
||||
})
|
||||
|
||||
It("adds the port to the hostname, if none is given", func() {
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||
Expect(client.hostname).To(Equal("quic.clemente.io:443"))
|
||||
})
|
||||
|
||||
It("dials", func() {
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
close(headerStream.unblockRead)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
// fmt.Println("done")
|
||||
}()
|
||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
||||
// make the go routine return
|
||||
injectResponse(5, &http.Response{})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors when dialing fails", func() {
|
||||
testErr := errors.New("handshake error")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err := client.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("uses the custom dialer, if provided", func() {
|
||||
var tlsCfg *tls.Config
|
||||
var qCfg *quic.Config
|
||||
session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)}
|
||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
||||
tlsCfg = tlsCfgP
|
||||
qCfg = cfg
|
||||
return session, nil
|
||||
}
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, dialer)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() quic.Session { return client.session }).Should(Equal(session))
|
||||
Expect(qCfg).To(Equal(client.config))
|
||||
Expect(tlsCfg).To(Equal(client.tlsConf))
|
||||
// make the go routine return
|
||||
injectResponse(5, &http.Response{})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors if it can't open a stream", func() {
|
||||
testErr := errors.New("you shall not pass")
|
||||
client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
|
||||
session.streamOpenErr = testErr
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
_, err := client.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("returns a request when dial fails", func() {
|
||||
testErr := errors.New("dial error")
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
close(done)
|
||||
}()
|
||||
_, err = client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("Doing requests", func() {
|
||||
var request *http.Request
|
||||
var dataStream *mockStream
|
||||
|
||||
getRequest := func(data []byte) *http2.MetaHeadersFrame {
|
||||
r := bytes.NewReader(data)
|
||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
h2framer := http2.NewFramer(nil, r)
|
||||
frame, err := h2framer.ReadFrame()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)}
|
||||
mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return mhframe
|
||||
}
|
||||
|
||||
getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string {
|
||||
fields := make(map[string]string)
|
||||
for _, hf := range f.Fields {
|
||||
fields[hf.Name] = hf.Value
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
dataStream = newMockStream(5)
|
||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
|
||||
request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("does a request", func() {
|
||||
teapot := &http.Response{
|
||||
Status: "418 I'm a teapot",
|
||||
StatusCode: 418,
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp).To(Equal(teapot))
|
||||
Expect(rsp.Body).To(Equal(dataStream))
|
||||
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
|
||||
Expect(rsp.Request).To(Equal(request))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
||||
injectResponse(5, teapot)
|
||||
Expect(client.headerErrored).ToNot(BeClosed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors if a request without a body is canceled", func() {
|
||||
done := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
request = request.WithContext(ctx)
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
Expect(rsp).To(BeNil())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(dataStream.reset).To(BeTrue())
|
||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
||||
Expect(client.headerErrored).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("errors if a request with a body is canceled after the body is sent", func() {
|
||||
done := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
request = request.WithContext(ctx)
|
||||
request.Body = &mockBody{}
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
Expect(rsp).To(BeNil())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(dataStream.reset).To(BeTrue())
|
||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
||||
Expect(client.headerErrored).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("errors if a request with a body is canceled before the body is sent", func() {
|
||||
done := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
request = request.WithContext(ctx)
|
||||
request.Body = &mockBody{}
|
||||
cancel()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
Expect(rsp).To(BeNil())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(dataStream.reset).To(BeTrue())
|
||||
Expect(dataStream.canceledWrite).To(BeTrue())
|
||||
Expect(client.headerErrored).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the quic client when encountering an error on the header stream", func() {
|
||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(client.headerErr))
|
||||
Expect(rsp).To(BeNil())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
||||
Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr))
|
||||
})
|
||||
|
||||
It("returns subsequent request if there was an error on the header stream before", func() {
|
||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)}
|
||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||
_, err := client.RoundTrip(request)
|
||||
Expect(err).To(BeAssignableToTypeOf(&qerr.QuicError{}))
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
||||
// now that the first request failed due to an error on the header stream, try another request
|
||||
_, nextErr := client.RoundTrip(request)
|
||||
Expect(nextErr).To(MatchError(err))
|
||||
})
|
||||
|
||||
It("blocks if no stream is available", func() {
|
||||
session.streamsToOpen = []quic.Stream{headerStream, dataStream}
|
||||
session.blockOpenStreamSync = true
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
// make the go routine return
|
||||
client.Close()
|
||||
injectResponse(5, &http.Response{})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("validating the address", func() {
|
||||
It("refuses to do requests for the wrong host", func() {
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.RoundTrip(req)
|
||||
Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
||||
})
|
||||
|
||||
It("refuses to do plain HTTP requests", func() {
|
||||
req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.RoundTrip(req)
|
||||
Expect(err).To(MatchError("quic http2: unsupported scheme"))
|
||||
})
|
||||
|
||||
It("adds the port for request URLs without one", func() {
|
||||
client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
// the client.RoundTrip will block, because the encryption level is still set to Unencrypted
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Consistently(done).ShouldNot(BeClosed())
|
||||
// make the go routine return
|
||||
injectResponse(5, &http.Response{})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
It("sets the EndStream header for requests without a body", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
client.RoundTrip(request)
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
|
||||
mhf := getRequest(headerStream.dataWritten.Bytes())
|
||||
Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue())
|
||||
// make the go routine return
|
||||
injectResponse(5, &http.Response{})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("sets the EndStream header to false for requests with a body", func() {
|
||||
request.Body = &mockBody{}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
client.RoundTrip(request)
|
||||
close(done)
|
||||
}()
|
||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil())
|
||||
mhf := getRequest(headerStream.dataWritten.Bytes())
|
||||
Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse())
|
||||
// make the go routine return
|
||||
injectResponse(5, &http.Response{})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("requests containing a Body", func() {
|
||||
var requestBody []byte
|
||||
var response *http.Response
|
||||
|
||||
BeforeEach(func() {
|
||||
requestBody = []byte("request body")
|
||||
body := &mockBody{}
|
||||
body.SetData(requestBody)
|
||||
request.Body = body
|
||||
response = &http.Response{
|
||||
StatusCode: 200,
|
||||
Header: http.Header{"Content-Length": []string{"1000"}},
|
||||
}
|
||||
// fake a handshake
|
||||
client.dialOnce.Do(func() {})
|
||||
session.streamsToOpen = []quic.Stream{dataStream}
|
||||
})
|
||||
|
||||
It("sends a request", func() {
|
||||
rspChan := make(chan *http.Response)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rspChan <- rsp
|
||||
}()
|
||||
injectResponse(5, response)
|
||||
Eventually(rspChan).Should(Receive(Equal(response)))
|
||||
Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody))
|
||||
Expect(dataStream.closed).To(BeTrue())
|
||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns the error that occurred when reading the body", func() {
|
||||
testErr := errors.New("testErr")
|
||||
request.Body.(*mockBody).readErr = testErr
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(rsp).To(BeNil())
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns the error that occurred when closing the body", func() {
|
||||
testErr := errors.New("testErr")
|
||||
request.Body.(*mockBody).closeErr = testErr
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(rsp).To(BeNil())
|
||||
close(done)
|
||||
}()
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(request.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("gzip compression", func() {
|
||||
var gzippedData []byte // a gzipped foobar
|
||||
var response *http.Response
|
||||
|
||||
BeforeEach(func() {
|
||||
var b bytes.Buffer
|
||||
w := gzip.NewWriter(&b)
|
||||
w.Write([]byte("foobar"))
|
||||
w.Close()
|
||||
gzippedData = b.Bytes()
|
||||
response = &http.Response{
|
||||
StatusCode: 200,
|
||||
Header: http.Header{"Content-Length": []string{"1000"}},
|
||||
}
|
||||
})
|
||||
|
||||
It("adds the gzip header to requests", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp).ToNot(BeNil())
|
||||
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
|
||||
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
|
||||
Expect(rsp.Header.Get("Content-Length")).To(BeEmpty())
|
||||
data := make([]byte, 6)
|
||||
_, err = io.ReadFull(rsp.Body, data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal([]byte("foobar")))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
dataStream.dataToRead.Write(gzippedData)
|
||||
response.Header.Add("Content-Encoding", "gzip")
|
||||
injectResponse(5, response)
|
||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
close(dataStream.unblockRead)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't add gzip if the header disable it", func() {
|
||||
client.opts.DisableCompression = true
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := client.RoundTrip(request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}()
|
||||
|
||||
Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty())
|
||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
||||
Expect(headers).ToNot(HaveKey("accept-encoding"))
|
||||
// make the go routine return
|
||||
injectResponse(5, &http.Response{})
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("only decompresses the response if the response contains the right content-encoding header", func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp).ToNot(BeNil())
|
||||
data := make([]byte, 11)
|
||||
rsp.Body.Read(data)
|
||||
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
|
||||
Expect(data).To(Equal([]byte("not gzipped")))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
dataStream.dataToRead.Write([]byte("not gzipped"))
|
||||
injectResponse(5, response)
|
||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't add the gzip header for requests that have the accept-enconding set", func() {
|
||||
request.Header.Add("accept-encoding", "gzip")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
rsp, err := client.RoundTrip(request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := make([]byte, 12)
|
||||
_, err = rsp.Body.Read(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
|
||||
Expect(data).To(Equal([]byte("gzipped data")))
|
||||
close(done)
|
||||
}()
|
||||
|
||||
dataStream.dataToRead.Write([]byte("gzipped data"))
|
||||
injectResponse(5, response)
|
||||
headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes()))
|
||||
Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("handling the header stream", func() {
|
||||
var h2framer *http2.Framer
|
||||
|
||||
BeforeEach(func() {
|
||||
h2framer = http2.NewFramer(&headerStream.dataToRead, nil)
|
||||
client.responses[23] = make(chan *http.Response)
|
||||
})
|
||||
|
||||
It("reads header values from a response", func() {
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d}
|
||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23})
|
||||
headerStream.dataToRead.Write(data)
|
||||
go client.handleHeaderStream()
|
||||
var rsp *http.Response
|
||||
Eventually(client.responses[23]).Should(Receive(&rsp))
|
||||
Expect(rsp).ToNot(BeNil())
|
||||
Expect(rsp.Proto).To(Equal("HTTP/2.0"))
|
||||
Expect(rsp.ProtoMajor).To(BeEquivalentTo(2))
|
||||
Expect(rsp.StatusCode).To(BeEquivalentTo(302))
|
||||
Expect(rsp.Status).To(Equal("302 Found"))
|
||||
Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"}))
|
||||
Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"}))
|
||||
})
|
||||
|
||||
It("errors if the H2 frame is not a HeadersFrame", func() {
|
||||
h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0})
|
||||
client.handleHeaderStream()
|
||||
Eventually(client.headerErrored).Should(BeClosed())
|
||||
Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")))
|
||||
})
|
||||
|
||||
It("errors if it can't read the HPACK encoded header fields", func() {
|
||||
h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: 23,
|
||||
EndHeaders: true,
|
||||
BlockFragment: []byte("invalid HPACK data"),
|
||||
})
|
||||
client.handleHeaderStream()
|
||||
Eventually(client.headerErrored).Should(BeClosed())
|
||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
||||
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("cannot read header fields"))
|
||||
})
|
||||
|
||||
It("errors if the stream cannot be found", func() {
|
||||
var headers bytes.Buffer
|
||||
enc := hpack.NewEncoder(&headers)
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
||||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: 1337,
|
||||
EndHeaders: true,
|
||||
BlockFragment: headers.Bytes(),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
client.handleHeaderStream()
|
||||
Eventually(client.headerErrored).Should(BeClosed())
|
||||
Expect(client.headerErr.ErrorCode).To(Equal(qerr.InvalidHeadersStreamData))
|
||||
Expect(client.headerErr.ErrorMessage).To(ContainSubstring("response channel for stream 1337 not found"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,35 @@
|
|||
package h2quic
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
// gzipReader wraps a response body so it can lazily
|
||||
// call gzip.NewReader on the first call to Read
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
)
|
||||
|
||||
// call gzip.NewReader on the first call to Read
|
||||
type gzipReader struct {
|
||||
body io.ReadCloser // underlying Response.Body
|
||||
zr *gzip.Reader // lazily-initialized gzip reader
|
||||
zerr error // sticky error
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Read(p []byte) (n int, err error) {
|
||||
if gz.zerr != nil {
|
||||
return 0, gz.zerr
|
||||
}
|
||||
if gz.zr == nil {
|
||||
gz.zr, err = gzip.NewReader(gz.body)
|
||||
if err != nil {
|
||||
gz.zerr = err
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return gz.zr.Read(p)
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Close() error {
|
||||
return gz.body.Close()
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestH2quic(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "H2quic Suite")
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
|
||||
var path, authority, method, contentLengthStr string
|
||||
httpHeaders := http.Header{}
|
||||
|
||||
for _, h := range headers {
|
||||
switch h.Name {
|
||||
case ":path":
|
||||
path = h.Value
|
||||
case ":method":
|
||||
method = h.Value
|
||||
case ":authority":
|
||||
authority = h.Value
|
||||
case "content-length":
|
||||
contentLengthStr = h.Value
|
||||
default:
|
||||
if !h.IsPseudo() {
|
||||
httpHeaders.Add(h.Name, h.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
|
||||
if len(httpHeaders["Cookie"]) > 0 {
|
||||
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
|
||||
}
|
||||
|
||||
if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
|
||||
return nil, errors.New(":path, :authority and :method must not be empty")
|
||||
}
|
||||
|
||||
u, err := url.Parse(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var contentLength int64
|
||||
if len(contentLengthStr) > 0 {
|
||||
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Method: method,
|
||||
URL: u,
|
||||
Proto: "HTTP/2.0",
|
||||
ProtoMajor: 2,
|
||||
ProtoMinor: 0,
|
||||
Header: httpHeaders,
|
||||
Body: nil,
|
||||
ContentLength: contentLength,
|
||||
Host: authority,
|
||||
RequestURI: path,
|
||||
TLS: &tls.ConnectionState{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func hostnameFromRequest(req *http.Request) string {
|
||||
if req.URL != nil {
|
||||
return req.URL.Host
|
||||
}
|
||||
return ""
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
type requestBody struct {
|
||||
requestRead bool
|
||||
dataStream quic.Stream
|
||||
}
|
||||
|
||||
// make sure the requestBody can be used as a http.Request.Body
|
||||
var _ io.ReadCloser = &requestBody{}
|
||||
|
||||
func newRequestBody(stream quic.Stream) *requestBody {
|
||||
return &requestBody{dataStream: stream}
|
||||
}
|
||||
|
||||
func (b *requestBody) Read(p []byte) (int, error) {
|
||||
b.requestRead = true
|
||||
return b.dataStream.Read(p)
|
||||
}
|
||||
|
||||
func (b *requestBody) Close() error {
|
||||
// stream's Close() closes the write side, not the read side
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Request body", func() {
|
||||
var (
|
||||
stream *mockStream
|
||||
rb *requestBody
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stream = &mockStream{}
|
||||
stream.dataToRead.Write([]byte("foobar")) // provides data to be read
|
||||
rb = newRequestBody(stream)
|
||||
})
|
||||
|
||||
It("reads from the stream", func() {
|
||||
b := make([]byte, 10)
|
||||
n, _ := stream.Read(b)
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b[0:6]).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("saves if the stream was read from", func() {
|
||||
Expect(rb.requestRead).To(BeFalse())
|
||||
rb.Read(make([]byte, 1))
|
||||
Expect(rb.requestRead).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't close the stream when closing the request body", func() {
|
||||
Expect(stream.closed).To(BeFalse())
|
||||
err := rb.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.closed).To(BeFalse())
|
||||
})
|
||||
})
|
|
@ -0,0 +1,121 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Request", func() {
|
||||
It("populates request", func() {
|
||||
headers := []hpack.HeaderField{
|
||||
{Name: ":path", Value: "/foo"},
|
||||
{Name: ":authority", Value: "quic.clemente.io"},
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: "content-length", Value: "42"},
|
||||
}
|
||||
req, err := requestFromHeaders(headers)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(req.Method).To(Equal("GET"))
|
||||
Expect(req.URL.Path).To(Equal("/foo"))
|
||||
Expect(req.Proto).To(Equal("HTTP/2.0"))
|
||||
Expect(req.ProtoMajor).To(Equal(2))
|
||||
Expect(req.ProtoMinor).To(Equal(0))
|
||||
Expect(req.ContentLength).To(Equal(int64(42)))
|
||||
Expect(req.Header).To(BeEmpty())
|
||||
Expect(req.Body).To(BeNil())
|
||||
Expect(req.Host).To(Equal("quic.clemente.io"))
|
||||
Expect(req.RequestURI).To(Equal("/foo"))
|
||||
Expect(req.TLS).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("concatenates the cookie headers", func() {
|
||||
headers := []hpack.HeaderField{
|
||||
{Name: ":path", Value: "/foo"},
|
||||
{Name: ":authority", Value: "quic.clemente.io"},
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: "cookie", Value: "cookie1=foobar1"},
|
||||
{Name: "cookie", Value: "cookie2=foobar2"},
|
||||
}
|
||||
req, err := requestFromHeaders(headers)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(req.Header).To(Equal(http.Header{
|
||||
"Cookie": []string{"cookie1=foobar1; cookie2=foobar2"},
|
||||
}))
|
||||
})
|
||||
|
||||
It("handles other headers", func() {
|
||||
headers := []hpack.HeaderField{
|
||||
{Name: ":path", Value: "/foo"},
|
||||
{Name: ":authority", Value: "quic.clemente.io"},
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: "cache-control", Value: "max-age=0"},
|
||||
{Name: "duplicate-header", Value: "1"},
|
||||
{Name: "duplicate-header", Value: "2"},
|
||||
}
|
||||
req, err := requestFromHeaders(headers)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(req.Header).To(Equal(http.Header{
|
||||
"Cache-Control": []string{"max-age=0"},
|
||||
"Duplicate-Header": []string{"1", "2"},
|
||||
}))
|
||||
})
|
||||
|
||||
It("errors with missing path", func() {
|
||||
headers := []hpack.HeaderField{
|
||||
{Name: ":authority", Value: "quic.clemente.io"},
|
||||
{Name: ":method", Value: "GET"},
|
||||
}
|
||||
_, err := requestFromHeaders(headers)
|
||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
||||
})
|
||||
|
||||
It("errors with missing method", func() {
|
||||
headers := []hpack.HeaderField{
|
||||
{Name: ":path", Value: "/foo"},
|
||||
{Name: ":authority", Value: "quic.clemente.io"},
|
||||
}
|
||||
_, err := requestFromHeaders(headers)
|
||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
||||
})
|
||||
|
||||
It("errors with missing authority", func() {
|
||||
headers := []hpack.HeaderField{
|
||||
{Name: ":path", Value: "/foo"},
|
||||
{Name: ":method", Value: "GET"},
|
||||
}
|
||||
_, err := requestFromHeaders(headers)
|
||||
Expect(err).To(MatchError(":path, :authority and :method must not be empty"))
|
||||
})
|
||||
|
||||
Context("extracting the hostname from a request", func() {
|
||||
var url *url.URL
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
url, err = url.Parse("https://quic.clemente.io:1337")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("uses req.URL.Host", func() {
|
||||
req := &http.Request{URL: url}
|
||||
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
|
||||
})
|
||||
|
||||
It("uses req.URL.Host even if req.Host is available", func() {
|
||||
req := &http.Request{
|
||||
Host: "www.example.org",
|
||||
URL: url,
|
||||
}
|
||||
Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337"))
|
||||
})
|
||||
|
||||
It("returns an empty hostname if nothing is set", func() {
|
||||
Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,203 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type requestWriter struct {
|
||||
mutex sync.Mutex
|
||||
headerStream quic.Stream
|
||||
|
||||
henc *hpack.Encoder
|
||||
hbuf bytes.Buffer // HPACK encoder writes into this
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
const defaultUserAgent = "quic-go"
|
||||
|
||||
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
|
||||
rw := &requestWriter{
|
||||
headerStream: headerStream,
|
||||
logger: logger,
|
||||
}
|
||||
rw.henc = hpack.NewEncoder(&rw.hbuf)
|
||||
return rw
|
||||
}
|
||||
|
||||
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
|
||||
// TODO: add support for trailers
|
||||
// TODO: add support for gzip compression
|
||||
// TODO: write continuation frames, if the header frame is too long
|
||||
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
|
||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
||||
return h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: uint32(dataStreamID),
|
||||
EndHeaders: true,
|
||||
EndStream: endStream,
|
||||
BlockFragment: w.hbuf.Bytes(),
|
||||
Priority: http2.PriorityParam{Weight: 0xff},
|
||||
})
|
||||
}
|
||||
|
||||
// the rest of this files is copied from http2.Transport
|
||||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
|
||||
w.hbuf.Reset()
|
||||
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
host, err := httpguts.PunycodeHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var path string
|
||||
if req.Method != "CONNECT" {
|
||||
path = req.URL.RequestURI()
|
||||
if !validPseudoPath(path) {
|
||||
orig := path
|
||||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
||||
if !validPseudoPath(path) {
|
||||
if req.URL.Opaque != "" {
|
||||
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid request :path %q", orig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for any invalid headers and return an error before we
|
||||
// potentially pollute our hpack state. (We want to be able to
|
||||
// continue to reuse the hpack encoder for future requests)
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("invalid HTTP header name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 8.1.2.3 Request Pseudo-Header Fields
|
||||
// The :path pseudo-header field includes the path and query parts of the
|
||||
// target URI (the path-absolute production and optionally a '?' character
|
||||
// followed by the query production (see Sections 3.3 and 3.4 of
|
||||
// [RFC3986]).
|
||||
w.writeHeader(":authority", host)
|
||||
w.writeHeader(":method", req.Method)
|
||||
if req.Method != "CONNECT" {
|
||||
w.writeHeader(":path", path)
|
||||
w.writeHeader(":scheme", req.URL.Scheme)
|
||||
}
|
||||
if trailers != "" {
|
||||
w.writeHeader("trailer", trailers)
|
||||
}
|
||||
|
||||
var didUA bool
|
||||
for k, vv := range req.Header {
|
||||
lowKey := strings.ToLower(k)
|
||||
switch lowKey {
|
||||
case "host", "content-length":
|
||||
// Host is :authority, already sent.
|
||||
// Content-Length is automatic, set below.
|
||||
continue
|
||||
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
|
||||
// Per 8.1.2.2 Connection-Specific Header
|
||||
// Fields, don't send connection-specific
|
||||
// fields. We have already checked if any
|
||||
// are error-worthy so just ignore the rest.
|
||||
continue
|
||||
case "user-agent":
|
||||
// Match Go's http1 behavior: at most one
|
||||
// User-Agent. If set to nil or empty string,
|
||||
// then omit it. Otherwise if not mentioned,
|
||||
// include the default (below).
|
||||
didUA = true
|
||||
if len(vv) < 1 {
|
||||
continue
|
||||
}
|
||||
vv = vv[:1]
|
||||
if vv[0] == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, v := range vv {
|
||||
w.writeHeader(lowKey, v)
|
||||
}
|
||||
}
|
||||
if shouldSendReqContentLength(req.Method, contentLength) {
|
||||
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
|
||||
}
|
||||
if addGzipHeader {
|
||||
w.writeHeader("accept-encoding", "gzip")
|
||||
}
|
||||
if !didUA {
|
||||
w.writeHeader("user-agent", defaultUserAgent)
|
||||
}
|
||||
return w.hbuf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (w *requestWriter) writeHeader(name, value string) {
|
||||
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
|
||||
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
|
||||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
||||
// a "content-length" request header. This logic is basically a copy of the net/http
|
||||
// transferWriter.shouldSendContentLength.
|
||||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
||||
// -1 means unknown.
|
||||
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
||||
if contentLength > 0 {
|
||||
return true
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return false
|
||||
}
|
||||
// For zero bodies, whether we send a content-length depends on the method.
|
||||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
||||
switch method {
|
||||
case "POST", "PUT", "PATCH":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validPseudoPath(v string) bool {
|
||||
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
|
||||
}
|
||||
|
||||
// actualContentLength returns a sanitized version of
|
||||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
||||
// means unknown.
|
||||
func actualContentLength(req *http.Request) int64 {
|
||||
if req.Body == nil {
|
||||
return 0
|
||||
}
|
||||
if req.ContentLength != 0 {
|
||||
return req.ContentLength
|
||||
}
|
||||
return -1
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Request", func() {
|
||||
var (
|
||||
rw *requestWriter
|
||||
headerStream *mockStream
|
||||
decoder *hpack.Decoder
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
headerStream = &mockStream{}
|
||||
rw = newRequestWriter(headerStream, utils.DefaultLogger)
|
||||
decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
})
|
||||
|
||||
decode := func(p []byte) (*http2.HeadersFrame, map[string] /* HeaderField.Name */ string /* HeaderField.Value */) {
|
||||
framer := http2.NewFramer(nil, bytes.NewReader(p))
|
||||
frame, err := framer.ReadFrame()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headerFrame := frame.(*http2.HeadersFrame)
|
||||
fields, err := decoder.DecodeFull(headerFrame.HeaderBlockFragment())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
values := make(map[string]string)
|
||||
for _, headerField := range fields {
|
||||
values[headerField.Name] = headerField.Value
|
||||
}
|
||||
return headerFrame, values
|
||||
}
|
||||
|
||||
It("writes a GET request", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rw.WriteRequest(req, 1337, true, false)
|
||||
headerFrame, headerFields := decode(headerStream.dataWritten.Bytes())
|
||||
Expect(headerFrame.StreamID).To(Equal(uint32(1337)))
|
||||
Expect(headerFrame.HasPriority()).To(BeTrue())
|
||||
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar"))
|
||||
Expect(headerFields).To(HaveKeyWithValue(":scheme", "https"))
|
||||
Expect(headerFields).ToNot(HaveKey("accept-encoding"))
|
||||
})
|
||||
|
||||
It("sets the EndStream header", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rw.WriteRequest(req, 1337, true, false)
|
||||
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
|
||||
Expect(headerFrame.StreamEnded()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't set the EndStream header, if requested", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rw.WriteRequest(req, 1337, false, false)
|
||||
headerFrame, _ := decode(headerStream.dataWritten.Bytes())
|
||||
Expect(headerFrame.StreamEnded()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("requests gzip compression, if requested", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rw.WriteRequest(req, 1337, true, true)
|
||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
||||
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
})
|
||||
|
||||
It("writes a POST request", func() {
|
||||
form := url.Values{}
|
||||
form.Add("foo", "bar")
|
||||
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
rw.WriteRequest(req, 5, true, false)
|
||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
||||
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
|
||||
Expect(headerFields).To(HaveKey("content-length"))
|
||||
contentLength, err := strconv.Atoi(headerFields["content-length"])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(contentLength).To(BeNumerically(">", 0))
|
||||
})
|
||||
|
||||
It("sends cookies", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cookie1 := &http.Cookie{
|
||||
Name: "Cookie #1",
|
||||
Value: "Value #1",
|
||||
}
|
||||
cookie2 := &http.Cookie{
|
||||
Name: "Cookie #2",
|
||||
Value: "Value #2",
|
||||
}
|
||||
req.AddCookie(cookie1)
|
||||
req.AddCookie(cookie2)
|
||||
rw.WriteRequest(req, 11, true, false)
|
||||
_, headerFields := decode(headerStream.dataWritten.Bytes())
|
||||
// TODO(lclemente): Remove Or() once we drop support for Go 1.8.
|
||||
Expect(headerFields).To(Or(
|
||||
HaveKeyWithValue("cookie", "Cookie #1=Value #1; Cookie #2=Value #2"),
|
||||
HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`),
|
||||
))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,95 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// copied from net/http2/transport.go
|
||||
|
||||
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
|
||||
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
|
||||
|
||||
// from the handleResponse function
|
||||
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
|
||||
if f.Truncated {
|
||||
return nil, errResponseHeaderListSize
|
||||
}
|
||||
|
||||
status := f.PseudoValue("status")
|
||||
if status == "" {
|
||||
return nil, errors.New("missing status pseudo header")
|
||||
}
|
||||
statusCode, err := strconv.Atoi(status)
|
||||
if err != nil {
|
||||
return nil, errors.New("malformed non-numeric status pseudo header")
|
||||
}
|
||||
|
||||
// TODO: handle statusCode == 100
|
||||
|
||||
header := make(http.Header)
|
||||
res := &http.Response{
|
||||
Proto: "HTTP/2.0",
|
||||
ProtoMajor: 2,
|
||||
Header: header,
|
||||
StatusCode: statusCode,
|
||||
Status: status + " " + http.StatusText(statusCode),
|
||||
}
|
||||
for _, hf := range f.RegularFields() {
|
||||
key := http.CanonicalHeaderKey(hf.Name)
|
||||
if key == "Trailer" {
|
||||
t := res.Trailer
|
||||
if t == nil {
|
||||
t = make(http.Header)
|
||||
res.Trailer = t
|
||||
}
|
||||
foreachHeaderElement(hf.Value, func(v string) {
|
||||
t[http.CanonicalHeaderKey(v)] = nil
|
||||
})
|
||||
} else {
|
||||
header[key] = append(header[key], hf.Value)
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// continuation of the handleResponse function
|
||||
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
|
||||
if !streamEnded || isHead {
|
||||
res.ContentLength = -1
|
||||
if clens := res.Header["Content-Length"]; len(clens) == 1 {
|
||||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
|
||||
res.ContentLength = clen64
|
||||
}
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// copied from net/http/server.go
|
||||
|
||||
// foreachHeaderElement splits v according to the "#rule" construction
|
||||
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
|
||||
func foreachHeaderElement(v string, fn func(string)) {
|
||||
v = textproto.TrimString(v)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
if !strings.Contains(v, ",") {
|
||||
fn(v)
|
||||
return
|
||||
}
|
||||
for _, f := range strings.Split(v, ",") {
|
||||
if f = textproto.TrimString(f); f != "" {
|
||||
fn(f)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type responseWriter struct {
|
||||
dataStreamID protocol.StreamID
|
||||
dataStream quic.Stream
|
||||
|
||||
headerStream quic.Stream
|
||||
headerStreamMutex *sync.Mutex
|
||||
|
||||
header http.Header
|
||||
status int // status code passed to WriteHeader
|
||||
headerWritten bool
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newResponseWriter(
|
||||
headerStream quic.Stream,
|
||||
headerStreamMutex *sync.Mutex,
|
||||
dataStream quic.Stream,
|
||||
dataStreamID protocol.StreamID,
|
||||
logger utils.Logger,
|
||||
) *responseWriter {
|
||||
return &responseWriter{
|
||||
header: http.Header{},
|
||||
headerStream: headerStream,
|
||||
headerStreamMutex: headerStreamMutex,
|
||||
dataStream: dataStream,
|
||||
dataStreamID: dataStreamID,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(status int) {
|
||||
if w.headerWritten {
|
||||
return
|
||||
}
|
||||
w.headerWritten = true
|
||||
w.status = status
|
||||
|
||||
var headers bytes.Buffer
|
||||
enc := hpack.NewEncoder(&headers)
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
||||
|
||||
for k, v := range w.header {
|
||||
for index := range v {
|
||||
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Infof("Responding with %d", status)
|
||||
w.headerStreamMutex.Lock()
|
||||
defer w.headerStreamMutex.Unlock()
|
||||
h2framer := http2.NewFramer(w.headerStream, nil)
|
||||
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: uint32(w.dataStreamID),
|
||||
EndHeaders: true,
|
||||
BlockFragment: headers.Bytes(),
|
||||
})
|
||||
if err != nil {
|
||||
w.logger.Errorf("could not write h2 header: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
if !bodyAllowedForStatus(w.status) {
|
||||
return 0, http.ErrBodyNotAllowed
|
||||
}
|
||||
return w.dataStream.Write(p)
|
||||
}
|
||||
|
||||
func (w *responseWriter) Flush() {}
|
||||
|
||||
// This is a NOP. Use http.Request.Context
|
||||
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
|
||||
|
||||
// test that we implement http.Flusher
|
||||
var _ http.Flusher = &responseWriter{}
|
||||
|
||||
// copied from http2/http2.go
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == 204:
|
||||
return false
|
||||
case status == 304:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package h2quic
|
||||
|
||||
import "net/http"
|
||||
|
||||
// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11.
|
||||
// By defining it in a separate file, we can exclude this file from staticcheck.
|
||||
|
||||
// test that we implement http.CloseNotifier
|
||||
var _ http.CloseNotifier = &responseWriter{}
|
|
@ -0,0 +1,163 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockStream struct {
|
||||
id protocol.StreamID
|
||||
dataToRead bytes.Buffer
|
||||
dataWritten bytes.Buffer
|
||||
reset bool
|
||||
canceledWrite bool
|
||||
closed bool
|
||||
remoteClosed bool
|
||||
|
||||
unblockRead chan struct{}
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
var _ quic.Stream = &mockStream{}
|
||||
|
||||
func newMockStream(id protocol.StreamID) *mockStream {
|
||||
s := &mockStream{
|
||||
id: id,
|
||||
unblockRead: make(chan struct{}),
|
||||
}
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil }
|
||||
func (s *mockStream) CancelRead(quic.ErrorCode) error { s.reset = true; return nil }
|
||||
func (s *mockStream) CancelWrite(quic.ErrorCode) error { s.canceledWrite = true; return nil }
|
||||
func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true; s.ctxCancel() }
|
||||
func (s mockStream) StreamID() protocol.StreamID { return s.id }
|
||||
func (s *mockStream) Context() context.Context { return s.ctx }
|
||||
func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") }
|
||||
func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") }
|
||||
func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") }
|
||||
|
||||
func (s *mockStream) Read(p []byte) (int, error) {
|
||||
n, _ := s.dataToRead.Read(p)
|
||||
if n == 0 { // block if there's no data
|
||||
<-s.unblockRead
|
||||
return 0, io.EOF
|
||||
}
|
||||
return n, nil // never return an EOF
|
||||
}
|
||||
func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) }
|
||||
|
||||
var _ = Describe("Response Writer", func() {
|
||||
var (
|
||||
w *responseWriter
|
||||
headerStream *mockStream
|
||||
dataStream *mockStream
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
headerStream = &mockStream{}
|
||||
dataStream = &mockStream{}
|
||||
w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5, utils.DefaultLogger)
|
||||
})
|
||||
|
||||
decodeHeaderFields := func() map[string][]string {
|
||||
fields := make(map[string][]string)
|
||||
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
|
||||
h2framer := http2.NewFramer(nil, bytes.NewReader(headerStream.dataWritten.Bytes()))
|
||||
|
||||
frame, err := h2framer.ReadFrame()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame).To(BeAssignableToTypeOf(&http2.HeadersFrame{}))
|
||||
hframe := frame.(*http2.HeadersFrame)
|
||||
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
|
||||
Expect(mhframe.StreamID).To(BeEquivalentTo(5))
|
||||
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for _, p := range mhframe.Fields {
|
||||
fields[p.Name] = append(fields[p.Name], p.Value)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
It("writes status", func() {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
fields := decodeHeaderFields()
|
||||
Expect(fields).To(HaveLen(1))
|
||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
||||
})
|
||||
|
||||
It("writes headers", func() {
|
||||
w.Header().Add("content-length", "42")
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
fields := decodeHeaderFields()
|
||||
Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"}))
|
||||
})
|
||||
|
||||
It("writes multiple headers with the same name", func() {
|
||||
const cookie1 = "test1=1; Max-Age=7200; path=/"
|
||||
const cookie2 = "test2=2; Max-Age=7200; path=/"
|
||||
w.Header().Add("set-cookie", cookie1)
|
||||
w.Header().Add("set-cookie", cookie2)
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
fields := decodeHeaderFields()
|
||||
Expect(fields).To(HaveKey("set-cookie"))
|
||||
cookies := fields["set-cookie"]
|
||||
Expect(cookies).To(ContainElement(cookie1))
|
||||
Expect(cookies).To(ContainElement(cookie2))
|
||||
})
|
||||
|
||||
It("writes data", func() {
|
||||
n, err := w.Write([]byte("foobar"))
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Should have written 200 on the header stream
|
||||
fields := decodeHeaderFields()
|
||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||
// And foobar on the data stream
|
||||
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("writes data after WriteHeader is called", func() {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
n, err := w.Write([]byte("foobar"))
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Should have written 418 on the header stream
|
||||
fields := decodeHeaderFields()
|
||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"418"}))
|
||||
// And foobar on the data stream
|
||||
Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("does not WriteHeader() twice", func() {
|
||||
w.WriteHeader(200)
|
||||
w.WriteHeader(500)
|
||||
fields := decodeHeaderFields()
|
||||
Expect(fields).To(HaveLen(1))
|
||||
Expect(fields).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||
})
|
||||
|
||||
It("doesn't allow writes if the status code doesn't allow a body", func() {
|
||||
w.WriteHeader(304)
|
||||
n, err := w.Write([]byte("foobar"))
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err).To(MatchError(http.ErrBodyNotAllowed))
|
||||
Expect(dataStream.dataWritten.Bytes()).To(HaveLen(0))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,179 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
type roundTripCloser interface {
|
||||
http.RoundTripper
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// RoundTripper implements the http.RoundTripper interface
|
||||
type RoundTripper struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
// DisableCompression, if true, prevents the Transport from
|
||||
// requesting compression with an "Accept-Encoding: gzip"
|
||||
// request header when the Request contains no existing
|
||||
// Accept-Encoding value. If the Transport requests gzip on
|
||||
// its own and gets a gzipped response, it's transparently
|
||||
// decoded in the Response.Body. However, if the user
|
||||
// explicitly requested gzip it is not automatically
|
||||
// uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// QuicConfig is the quic.Config used for dialing new connections.
|
||||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, quic.DialAddr will be used.
|
||||
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
||||
|
||||
clients map[string]roundTripCloser
|
||||
}
|
||||
|
||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
type RoundTripOpt struct {
|
||||
// OnlyCachedConn controls whether the RoundTripper may
|
||||
// create a new QUIC connection. If set true and
|
||||
// no cached connection is available, RoundTrip
|
||||
// will return ErrNoCachedConn.
|
||||
OnlyCachedConn bool
|
||||
}
|
||||
|
||||
var _ roundTripCloser = &RoundTripper{}
|
||||
|
||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
||||
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
|
||||
|
||||
// RoundTripOpt is like RoundTrip, but takes options.
|
||||
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if req.URL == nil {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("quic: nil Request.URL")
|
||||
}
|
||||
if req.URL.Host == "" {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("quic: no Host in request URL")
|
||||
}
|
||||
if req.Header == nil {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("quic: nil Request.Header")
|
||||
}
|
||||
|
||||
if req.URL.Scheme == "https" {
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
closeRequestBody(req)
|
||||
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
|
||||
}
|
||||
|
||||
if req.Method != "" && !validMethod(req.Method) {
|
||||
closeRequestBody(req)
|
||||
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
|
||||
}
|
||||
|
||||
hostname := authorityAddr("https", hostnameFromRequest(req))
|
||||
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cl.RoundTrip(req)
|
||||
}
|
||||
|
||||
// RoundTrip does a round trip.
|
||||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if r.clients == nil {
|
||||
r.clients = make(map[string]roundTripCloser)
|
||||
}
|
||||
|
||||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
if onlyCached {
|
||||
return nil, ErrNoCachedConn
|
||||
}
|
||||
client = newClient(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{DisableCompression: r.DisableCompression},
|
||||
r.QuicConfig,
|
||||
r.Dial,
|
||||
)
|
||||
r.clients[hostname] = client
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this RoundTripper has used
|
||||
func (r *RoundTripper) Close() error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for _, client := range r.clients {
|
||||
if err := client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.clients = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeRequestBody(req *http.Request) {
|
||||
if req.Body != nil {
|
||||
req.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
/*
|
||||
Method = "OPTIONS" ; Section 9.2
|
||||
| "GET" ; Section 9.3
|
||||
| "HEAD" ; Section 9.4
|
||||
| "POST" ; Section 9.5
|
||||
| "PUT" ; Section 9.6
|
||||
| "DELETE" ; Section 9.7
|
||||
| "TRACE" ; Section 9.8
|
||||
| "CONNECT" ; Section 9.9
|
||||
| extension-method
|
||||
extension-method = token
|
||||
token = 1*<any CHAR except CTLs or separators>
|
||||
*/
|
||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||
}
|
||||
|
||||
// copied from net/http/http.go
|
||||
func isNotToken(r rune) bool {
|
||||
return !httpguts.IsTokenRune(r)
|
||||
}
|
|
@ -0,0 +1,218 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{Request: req}, nil
|
||||
}
|
||||
func (m *mockClient) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ roundTripCloser = &mockClient{}
|
||||
|
||||
type mockBody struct {
|
||||
reader bytes.Reader
|
||||
readErr error
|
||||
closeErr error
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockBody) Read(p []byte) (int, error) {
|
||||
if m.readErr != nil {
|
||||
return 0, m.readErr
|
||||
}
|
||||
return m.reader.Read(p)
|
||||
}
|
||||
|
||||
func (m *mockBody) SetData(data []byte) {
|
||||
m.reader = *bytes.NewReader(data)
|
||||
}
|
||||
|
||||
func (m *mockBody) Close() error {
|
||||
m.closed = true
|
||||
return m.closeErr
|
||||
}
|
||||
|
||||
// make sure the mockBody can be used as a http.Request.Body
|
||||
var _ io.ReadCloser = &mockBody{}
|
||||
|
||||
var _ = Describe("RoundTripper", func() {
|
||||
var (
|
||||
rt *RoundTripper
|
||||
req1 *http.Request
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
rt = &RoundTripper{}
|
||||
var err error
|
||||
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("dialing hosts", func() {
|
||||
origDialAddr := dialAddr
|
||||
streamOpenErr := errors.New("error opening stream")
|
||||
|
||||
BeforeEach(func() {
|
||||
origDialAddr = dialAddr
|
||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
||||
// return an error when trying to open a stream
|
||||
// we don't want to test all the dial logic here, just that dialing happens at all
|
||||
return &mockSession{streamOpenErr: streamOpenErr}, nil
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
dialAddr = origDialAddr
|
||||
})
|
||||
|
||||
It("creates new clients", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError(streamOpenErr))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("uses the quic.Config, if provided", func() {
|
||||
config := &quic.Config{HandshakeTimeout: time.Millisecond}
|
||||
var receivedConfig *quic.Config
|
||||
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) {
|
||||
receivedConfig = config
|
||||
return nil, errors.New("err")
|
||||
}
|
||||
rt.QuicConfig = config
|
||||
rt.RoundTrip(req1)
|
||||
Expect(receivedConfig).To(Equal(config))
|
||||
})
|
||||
|
||||
It("uses the custom dialer, if provided", func() {
|
||||
var dialed bool
|
||||
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.Session, error) {
|
||||
dialed = true
|
||||
return nil, errors.New("err")
|
||||
}
|
||||
rt.Dial = dialer
|
||||
rt.RoundTrip(req1)
|
||||
Expect(dialed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("reuses existing clients", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError(streamOpenErr))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req2)
|
||||
Expect(err).To(MatchError(streamOpenErr))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true})
|
||||
Expect(err).To(MatchError(ErrNoCachedConn))
|
||||
})
|
||||
})
|
||||
|
||||
Context("validating request", func() {
|
||||
It("rejects plain HTTP requests", func() {
|
||||
req, err := http.NewRequest("GET", "http://www.example.org/", nil)
|
||||
req.Body = &mockBody{}
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("quic: unsupported protocol scheme: http"))
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects requests without a URL", func() {
|
||||
req1.URL = nil
|
||||
req1.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError("quic: nil Request.URL"))
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects request without a URL Host", func() {
|
||||
req1.URL.Host = ""
|
||||
req1.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError("quic: no Host in request URL"))
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't try to close the body if the request doesn't have one", func() {
|
||||
req1.URL = nil
|
||||
Expect(req1.Body).To(BeNil())
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError("quic: nil Request.URL"))
|
||||
})
|
||||
|
||||
It("rejects requests without a header", func() {
|
||||
req1.Header = nil
|
||||
req1.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError("quic: nil Request.Header"))
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects requests with invalid header name fields", func() {
|
||||
req1.Header.Add("foobär", "value")
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError("quic: invalid http header field name \"foobär\""))
|
||||
})
|
||||
|
||||
It("rejects requests with invalid header name values", func() {
|
||||
req1.Header.Add("foo", string([]byte{0x7}))
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value"))
|
||||
})
|
||||
|
||||
It("rejects requests with an invalid request method", func() {
|
||||
req1.Method = "foobär"
|
||||
req1.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError("quic: invalid method \"foobär\""))
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
It("closes", func() {
|
||||
rt.clients = make(map[string]roundTripCloser)
|
||||
cl := &mockClient{}
|
||||
rt.clients["foo.bar"] = cl
|
||||
err := rt.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
Expect(cl.closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("closes a RoundTripper that has never been used", func() {
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
err := rt.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,402 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type streamCreator interface {
|
||||
quic.Session
|
||||
GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
|
||||
}
|
||||
|
||||
type remoteCloser interface {
|
||||
CloseRemote(protocol.ByteCount)
|
||||
}
|
||||
|
||||
// allows mocking of quic.Listen and quic.ListenAddr
|
||||
var (
|
||||
quicListen = quic.Listen
|
||||
quicListenAddr = quic.ListenAddr
|
||||
)
|
||||
|
||||
// Server is a HTTP2 server listening for QUIC connections.
|
||||
type Server struct {
|
||||
*http.Server
|
||||
|
||||
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
|
||||
// If nil, it uses reasonable default values.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Private flag for demo, do not use
|
||||
CloseAfterFirstRequest bool
|
||||
|
||||
port uint32 // used atomically
|
||||
|
||||
listenerMutex sync.Mutex
|
||||
listener quic.Listener
|
||||
closed bool
|
||||
|
||||
supportedVersionsAsString string
|
||||
|
||||
logger utils.Logger // will be set by Server.serveImpl()
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
return s.serveImpl(s.TLSConfig, nil)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
|
||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
return s.serveImpl(config, nil)
|
||||
}
|
||||
|
||||
// Serve an existing UDP connection.
|
||||
func (s *Server) Serve(conn net.PacketConn) error {
|
||||
return s.serveImpl(s.TLSConfig, conn)
|
||||
}
|
||||
|
||||
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
|
||||
if s.Server == nil {
|
||||
return errors.New("use of h2quic.Server without http.Server")
|
||||
}
|
||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
||||
s.listenerMutex.Lock()
|
||||
if s.closed {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("Server is already closed")
|
||||
}
|
||||
if s.listener != nil {
|
||||
s.listenerMutex.Unlock()
|
||||
return errors.New("ListenAndServe may only be called once")
|
||||
}
|
||||
|
||||
var ln quic.Listener
|
||||
var err error
|
||||
if conn == nil {
|
||||
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
|
||||
} else {
|
||||
ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
|
||||
}
|
||||
if err != nil {
|
||||
s.listenerMutex.Unlock()
|
||||
return err
|
||||
}
|
||||
s.listener = ln
|
||||
s.listenerMutex.Unlock()
|
||||
|
||||
for {
|
||||
sess, err := ln.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handleHeaderStream(sess.(streamCreator))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleHeaderStream(session streamCreator) {
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
|
||||
return
|
||||
}
|
||||
|
||||
hpackDecoder := hpack.NewDecoder(4096, nil)
|
||||
h2framer := http2.NewFramer(nil, stream)
|
||||
|
||||
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
|
||||
for {
|
||||
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
|
||||
// QuicErrors must originate from stream.Read() returning an error.
|
||||
// In this case, the session has already logged the error, so we don't
|
||||
// need to log it again.
|
||||
errorCode := qerr.InternalError
|
||||
if qerr, ok := err.(*qerr.QuicError); !ok {
|
||||
errorCode = qerr.ErrorCode
|
||||
s.logger.Errorf("error handling h2 request: %s", err.Error())
|
||||
}
|
||||
session.CloseWithError(quic.ErrorCode(errorCode), err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
|
||||
h2frame, err := h2framer.ReadFrame()
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
|
||||
}
|
||||
var h2headersFrame *http2.HeadersFrame
|
||||
switch f := h2frame.(type) {
|
||||
case *http2.PriorityFrame:
|
||||
// ignore PRIORITY frames
|
||||
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
|
||||
return nil
|
||||
case *http2.HeadersFrame:
|
||||
h2headersFrame = f
|
||||
default:
|
||||
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
|
||||
}
|
||||
|
||||
if !h2headersFrame.HeadersEnded() {
|
||||
return errors.New("http2 header continuation not implemented")
|
||||
}
|
||||
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
|
||||
if err != nil {
|
||||
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := requestFromHeaders(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
|
||||
} else {
|
||||
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
||||
}
|
||||
|
||||
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
|
||||
if dataStream == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRequest should be as non-blocking as possible to minimize
|
||||
// head-of-line blocking. Potentially blocking code is run in a separate
|
||||
// goroutine, enabling handleRequest to return before the code is executed.
|
||||
go func() {
|
||||
streamEnded := h2headersFrame.StreamEnded()
|
||||
if streamEnded {
|
||||
dataStream.(remoteCloser).CloseRemote(0)
|
||||
streamEnded = true
|
||||
_, _ = dataStream.Read([]byte{0}) // read the eof
|
||||
}
|
||||
|
||||
req = req.WithContext(dataStream.Context())
|
||||
reqBody := newRequestBody(dataStream)
|
||||
req.Body = reqBody
|
||||
|
||||
req.RemoteAddr = session.RemoteAddr().String()
|
||||
|
||||
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
|
||||
|
||||
handler := s.Handler
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
panicked := false
|
||||
func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
// Copied from net/http/server.go
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
handler.ServeHTTP(responseWriter, req)
|
||||
}()
|
||||
if panicked {
|
||||
responseWriter.WriteHeader(500)
|
||||
} else {
|
||||
responseWriter.WriteHeader(200)
|
||||
}
|
||||
if responseWriter.dataStream != nil {
|
||||
if !streamEnded && !reqBody.requestRead {
|
||||
// in gQUIC, the error code doesn't matter, so just use 0 here
|
||||
responseWriter.dataStream.CancelRead(0)
|
||||
}
|
||||
responseWriter.dataStream.Close()
|
||||
}
|
||||
if s.CloseAfterFirstRequest {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
session.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
|
||||
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
func (s *Server) Close() error {
|
||||
s.listenerMutex.Lock()
|
||||
defer s.listenerMutex.Unlock()
|
||||
s.closed = true
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
|
||||
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
func (s *Server) CloseGracefully(timeout time.Duration) error {
|
||||
// TODO: implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
|
||||
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
|
||||
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
|
||||
func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
||||
port := atomic.LoadUint32(&s.port)
|
||||
|
||||
if port == 0 {
|
||||
// Extract port from s.Server.Addr
|
||||
_, portStr, err := net.SplitHostPort(s.Server.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
portInt, err := net.LookupPort("tcp", portStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
port = uint32(portInt)
|
||||
atomic.StoreUint32(&s.port, port)
|
||||
}
|
||||
|
||||
if s.supportedVersionsAsString == "" {
|
||||
var versions []string
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
versions = append(versions, v.ToAltSvc())
|
||||
}
|
||||
s.supportedVersionsAsString = strings.Join(versions, ",")
|
||||
}
|
||||
|
||||
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
||||
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
|
||||
// used when handler is nil.
|
||||
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
server := &Server{
|
||||
Server: &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
},
|
||||
}
|
||||
return server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the given network address for both, TLS and QUIC
|
||||
// connetions in parallel. It returns if one of the two returns an error.
|
||||
// http.DefaultServeMux is used when handler is nil.
|
||||
// The correct Alt-Svc headers for QUIC are set.
|
||||
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
// Load certs
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
|
||||
// Open the listeners
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tcpConn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tcpConn.Close()
|
||||
|
||||
tlsConn := tls.NewListener(tcpConn, config)
|
||||
defer tlsConn.Close()
|
||||
|
||||
// Start the servers
|
||||
httpServer := &http.Server{
|
||||
Addr: addr,
|
||||
TLSConfig: config,
|
||||
}
|
||||
|
||||
quicServer := &Server{
|
||||
Server: httpServer,
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
quicServer.SetQuicHeaders(w.Header())
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
hErr := make(chan error)
|
||||
qErr := make(chan error)
|
||||
go func() {
|
||||
hErr <- httpServer.Serve(tlsConn)
|
||||
}()
|
||||
go func() {
|
||||
qErr <- quicServer.Serve(udpConn)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-hErr:
|
||||
quicServer.Close()
|
||||
return err
|
||||
case err := <-qErr:
|
||||
// Cannot close the HTTP server or wait for requests to complete properly :/
|
||||
return err
|
||||
}
|
||||
}
|
|
@ -0,0 +1,536 @@
|
|||
package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockSession struct {
|
||||
closed bool
|
||||
closedWithError error
|
||||
dataStream quic.Stream
|
||||
streamToAccept quic.Stream
|
||||
streamsToOpen []quic.Stream
|
||||
blockOpenStreamSync bool
|
||||
blockOpenStreamChan chan struct{} // close this chan (or call Close) to make OpenStreamSync return
|
||||
streamOpenErr error
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newMockSession() *mockSession {
|
||||
return &mockSession{blockOpenStreamChan: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (quic.Stream, error) {
|
||||
return s.dataStream, nil
|
||||
}
|
||||
func (s *mockSession) AcceptStream() (quic.Stream, error) { return s.streamToAccept, nil }
|
||||
func (s *mockSession) OpenStream() (quic.Stream, error) {
|
||||
if s.streamOpenErr != nil {
|
||||
return nil, s.streamOpenErr
|
||||
}
|
||||
str := s.streamsToOpen[0]
|
||||
s.streamsToOpen = s.streamsToOpen[1:]
|
||||
return str, nil
|
||||
}
|
||||
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
|
||||
if s.blockOpenStreamSync {
|
||||
<-s.blockOpenStreamChan
|
||||
}
|
||||
return s.OpenStream()
|
||||
}
|
||||
func (s *mockSession) Close() error {
|
||||
s.ctxCancel()
|
||||
if !s.closed {
|
||||
close(s.blockOpenStreamChan)
|
||||
}
|
||||
s.closed = true
|
||||
return nil
|
||||
}
|
||||
func (s *mockSession) CloseWithError(_ quic.ErrorCode, e error) error {
|
||||
s.closedWithError = e
|
||||
return s.Close()
|
||||
}
|
||||
func (s *mockSession) LocalAddr() net.Addr {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (s *mockSession) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
|
||||
}
|
||||
func (s *mockSession) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
func (s *mockSession) ConnectionState() quic.ConnectionState { panic("not implemented") }
|
||||
func (s *mockSession) AcceptUniStream() (quic.ReceiveStream, error) { panic("not implemented") }
|
||||
func (s *mockSession) OpenUniStream() (quic.SendStream, error) { panic("not implemented") }
|
||||
func (s *mockSession) OpenUniStreamSync() (quic.SendStream, error) { panic("not implemented") }
|
||||
|
||||
var _ = Describe("H2 server", func() {
|
||||
var (
|
||||
s *Server
|
||||
session *mockSession
|
||||
dataStream *mockStream
|
||||
origQuicListenAddr = quicListenAddr
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
s = &Server{
|
||||
Server: &http.Server{
|
||||
TLSConfig: testdata.GetTLSConfig(),
|
||||
},
|
||||
logger: utils.DefaultLogger,
|
||||
}
|
||||
dataStream = newMockStream(0)
|
||||
close(dataStream.unblockRead)
|
||||
session = newMockSession()
|
||||
session.dataStream = dataStream
|
||||
session.ctx, session.ctxCancel = context.WithCancel(context.Background())
|
||||
origQuicListenAddr = quicListenAddr
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
quicListenAddr = origQuicListenAddr
|
||||
})
|
||||
|
||||
Context("handling requests", func() {
|
||||
var (
|
||||
h2framer *http2.Framer
|
||||
hpackDecoder *hpack.Decoder
|
||||
headerStream *mockStream
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
headerStream = &mockStream{}
|
||||
hpackDecoder = hpack.NewDecoder(4096, nil)
|
||||
h2framer = http2.NewFramer(nil, headerStream)
|
||||
})
|
||||
|
||||
It("handles a sample GET request", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
Expect(r.RemoteAddr).To(Equal("127.0.0.1:42"))
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
Expect(dataStream.remoteClosed).To(BeTrue())
|
||||
Expect(dataStream.reset).To(BeFalse())
|
||||
})
|
||||
|
||||
It("returns 200 with an empty handler", func() {
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() []byte {
|
||||
return headerStream.dataWritten.Bytes()
|
||||
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200
|
||||
})
|
||||
|
||||
It("correctly handles a panicking handler", func() {
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("foobar")
|
||||
})
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() []byte {
|
||||
return headerStream.dataWritten.Bytes()
|
||||
}).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500
|
||||
})
|
||||
|
||||
It("resets the dataStream when client sends a body in GET request", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
|
||||
Expect(dataStream.remoteClosed).To(BeFalse())
|
||||
})
|
||||
|
||||
It("resets the dataStream when the body of POST request is not read", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
Expect(r.Method).To(Equal("POST"))
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
|
||||
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
|
||||
Expect(handlerCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
It("handles a request for which the client immediately resets the data stream", func() {
|
||||
session.dataStream = nil
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("resets the dataStream when the body of POST request is not read, and the request handler replaces the request.Body", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}{}
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return dataStream.reset }).Should(BeTrue())
|
||||
Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse())
|
||||
Expect(handlerCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
It("closes the dataStream if the body of POST request was read", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
Expect(r.Method).To(Equal("POST"))
|
||||
handlerCalled = true
|
||||
// read the request body
|
||||
b := make([]byte, 1000)
|
||||
n, _ := r.Body.Read(b)
|
||||
Expect(n).ToNot(BeZero())
|
||||
})
|
||||
headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7})
|
||||
dataStream.dataToRead.Write([]byte("foo=bar"))
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
Expect(dataStream.reset).To(BeFalse())
|
||||
})
|
||||
|
||||
It("ignores PRIORITY frames", func() {
|
||||
handlerCalled := make(chan struct{})
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(handlerCalled)
|
||||
})
|
||||
buf := &bytes.Buffer{}
|
||||
framer := http2.NewFramer(buf, nil)
|
||||
err := framer.WritePriority(10, http2.PriorityParam{Weight: 42})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(buf.Bytes()).ToNot(BeEmpty())
|
||||
headerStream.dataToRead.Write(buf.Bytes())
|
||||
err = s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Consistently(handlerCalled).ShouldNot(BeClosed())
|
||||
Expect(dataStream.reset).To(BeFalse())
|
||||
Expect(dataStream.closed).To(BeFalse())
|
||||
})
|
||||
|
||||
It("errors when non-header frames are received", func() {
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x06, 0x0, 0x0, 0x0, 0x0, 0x0, 0x5,
|
||||
'f', 'o', 'o', 'b', 'a', 'r',
|
||||
})
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame"))
|
||||
})
|
||||
|
||||
It("Cancels the request context when the datstream is closed", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
err := r.Context().Err()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal("context canceled"))
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
dataStream.Close()
|
||||
err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
Expect(dataStream.remoteClosed).To(BeTrue())
|
||||
Expect(dataStream.reset).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
It("handles the header stream", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream := &mockStream{id: 3}
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
session.streamToAccept = headerStream
|
||||
go s.handleHeaderStream(session)
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("closes the connection if it encounters an error on the header stream", func() {
|
||||
var handlerCalled bool
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlerCalled = true
|
||||
})
|
||||
headerStream := &mockStream{id: 3}
|
||||
headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100))
|
||||
session.streamToAccept = headerStream
|
||||
go s.handleHeaderStream(session)
|
||||
Consistently(func() bool { return handlerCalled }).Should(BeFalse())
|
||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
||||
Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")))
|
||||
})
|
||||
|
||||
It("supports closing after first request", func() {
|
||||
s.CloseAfterFirstRequest = true
|
||||
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
headerStream := &mockStream{id: 3}
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
session.streamToAccept = headerStream
|
||||
Expect(session.closed).To(BeFalse())
|
||||
go s.handleHeaderStream(session)
|
||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
||||
})
|
||||
|
||||
It("uses the default handler as fallback", func() {
|
||||
var handlerCalled bool
|
||||
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.Host).To(Equal("www.example.com"))
|
||||
handlerCalled = true
|
||||
}))
|
||||
headerStream := &mockStream{id: 3}
|
||||
headerStream.dataToRead.Write([]byte{
|
||||
0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5,
|
||||
// Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding
|
||||
0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff,
|
||||
})
|
||||
session.streamToAccept = headerStream
|
||||
go s.handleHeaderStream(session)
|
||||
Eventually(func() bool { return handlerCalled }).Should(BeTrue())
|
||||
})
|
||||
|
||||
Context("setting http headers", func() {
|
||||
var expected http.Header
|
||||
|
||||
getExpectedHeader := func(versions []protocol.VersionNumber) http.Header {
|
||||
var versionsAsString []string
|
||||
for _, v := range versions {
|
||||
versionsAsString = append(versionsAsString, v.ToAltSvc())
|
||||
}
|
||||
return http.Header{
|
||||
"Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))},
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
Expect(getExpectedHeader([]protocol.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}}))
|
||||
expected = getExpectedHeader(protocol.SupportedVersions)
|
||||
})
|
||||
|
||||
It("sets proper headers with numeric port", func() {
|
||||
s.Server.Addr = ":443"
|
||||
hdr := http.Header{}
|
||||
err := s.SetQuicHeaders(hdr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(hdr).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("sets proper headers with full addr", func() {
|
||||
s.Server.Addr = "127.0.0.1:443"
|
||||
hdr := http.Header{}
|
||||
err := s.SetQuicHeaders(hdr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(hdr).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("sets proper headers with string port", func() {
|
||||
s.Server.Addr = ":https"
|
||||
hdr := http.Header{}
|
||||
err := s.SetQuicHeaders(hdr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(hdr).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("works multiple times", func() {
|
||||
s.Server.Addr = ":https"
|
||||
hdr := http.Header{}
|
||||
err := s.SetQuicHeaders(hdr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(hdr).To(Equal(expected))
|
||||
hdr = http.Header{}
|
||||
err = s.SetQuicHeaders(hdr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(hdr).To(Equal(expected))
|
||||
})
|
||||
})
|
||||
|
||||
It("should error when ListenAndServe is called with s.Server nil", func() {
|
||||
err := (&Server{}).ListenAndServe()
|
||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
||||
})
|
||||
|
||||
It("should error when ListenAndServeTLS is called with s.Server nil", func() {
|
||||
err := (&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())
|
||||
Expect(err).To(MatchError("use of h2quic.Server without http.Server"))
|
||||
})
|
||||
|
||||
It("should nop-Close() when s.server is nil", func() {
|
||||
err := (&Server{}).Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when ListenAndServer is called after Close", func() {
|
||||
serv := &Server{Server: &http.Server{}}
|
||||
Expect(serv.Close()).To(Succeed())
|
||||
err := serv.ListenAndServe()
|
||||
Expect(err).To(MatchError("Server is already closed"))
|
||||
})
|
||||
|
||||
Context("ListenAndServe", func() {
|
||||
BeforeEach(func() {
|
||||
s.Server.Addr = "localhost:0"
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(s.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
It("may only be called once", func() {
|
||||
cErr := make(chan error)
|
||||
for i := 0; i < 2; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := s.ListenAndServe()
|
||||
if err != nil {
|
||||
cErr <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
err := <-cErr
|
||||
Expect(err).To(MatchError("ListenAndServe may only be called once"))
|
||||
err = s.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}, 0.5)
|
||||
|
||||
It("uses the quic.Config to start the quic server", func() {
|
||||
conf := &quic.Config{HandshakeTimeout: time.Nanosecond}
|
||||
var receivedConf *quic.Config
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
||||
receivedConf = config
|
||||
return nil, errors.New("listen err")
|
||||
}
|
||||
s.QuicConfig = conf
|
||||
go s.ListenAndServe()
|
||||
Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ListenAndServeTLS", func() {
|
||||
BeforeEach(func() {
|
||||
s.Server.Addr = "localhost:0"
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
err := s.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("may only be called once", func() {
|
||||
cErr := make(chan error)
|
||||
for i := 0; i < 2; i++ {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := s.ListenAndServeTLS(testdata.GetCertificatePaths())
|
||||
if err != nil {
|
||||
cErr <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
err := <-cErr
|
||||
Expect(err).To(MatchError("ListenAndServe may only be called once"))
|
||||
err = s.Close()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}, 0.5)
|
||||
})
|
||||
|
||||
It("closes gracefully", func() {
|
||||
err := s.CloseGracefully(0)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors when listening fails", func() {
|
||||
testErr := errors.New("listen error")
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
fullpem, privkey := testdata.GetCertificatePaths()
|
||||
err := ListenAndServeQUIC("", fullpem, privkey, nil)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,210 @@
|
|||
package chrome_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gexec"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
dataLen = 500 * 1024 // 500 KB
|
||||
dataLongLen = 50 * 1024 * 1024 // 50 MB
|
||||
)
|
||||
|
||||
var (
|
||||
nFilesUploaded int32 // should be used atomically
|
||||
doneCalled utils.AtomicBool
|
||||
)
|
||||
|
||||
func TestChrome(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Chrome Suite")
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Requires the len & num GET parameters, e.g. /uploadtest?len=100&num=1
|
||||
http.HandleFunc("/uploadtest", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
response := uploadHTML
|
||||
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
|
||||
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
|
||||
_, err := io.WriteString(w, response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
// Requires the len & num GET parameters, e.g. /downloadtest?len=100&num=1
|
||||
http.HandleFunc("/downloadtest", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
response := downloadHTML
|
||||
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
|
||||
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
|
||||
_, err := io.WriteString(w, response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
http.HandleFunc("/uploadhandler", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
|
||||
l, err := strconv.Atoi(r.URL.Query().Get("len"))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
defer r.Body.Close()
|
||||
actual, err := ioutil.ReadAll(r.Body)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(bytes.Equal(actual, testserver.GeneratePRData(l))).To(BeTrue())
|
||||
|
||||
atomic.AddInt32(&nFilesUploaded, 1)
|
||||
})
|
||||
|
||||
http.HandleFunc("/done", func(w http.ResponseWriter, r *http.Request) {
|
||||
doneCalled.Set(true)
|
||||
})
|
||||
}
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
testserver.StopQuicServer()
|
||||
|
||||
atomic.StoreInt32(&nFilesUploaded, 0)
|
||||
doneCalled.Set(false)
|
||||
})
|
||||
|
||||
func getChromePath() string {
|
||||
if runtime.GOOS == "darwin" {
|
||||
return "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
|
||||
}
|
||||
if path, err := exec.LookPath("google-chrome"); err == nil {
|
||||
return path
|
||||
}
|
||||
if path, err := exec.LookPath("chromium-browser"); err == nil {
|
||||
return path
|
||||
}
|
||||
Fail("No Chrome executable found.")
|
||||
return ""
|
||||
}
|
||||
|
||||
func chromeTest(version protocol.VersionNumber, url string, blockUntilDone func()) {
|
||||
userDataDir, err := ioutil.TempDir("", "quic-go-test-chrome-dir")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer os.RemoveAll(userDataDir)
|
||||
path := getChromePath()
|
||||
args := []string{
|
||||
"--disable-gpu",
|
||||
"--no-first-run=true",
|
||||
"--no-default-browser-check=true",
|
||||
"--user-data-dir=" + userDataDir,
|
||||
"--enable-quic=true",
|
||||
"--no-proxy-server=true",
|
||||
"--no-sandbox",
|
||||
"--origin-to-force-quic-on=quic.clemente.io:443",
|
||||
fmt.Sprintf(`--host-resolver-rules=MAP quic.clemente.io:443 127.0.0.1:%s`, testserver.Port()),
|
||||
fmt.Sprintf("--quic-version=QUIC_VERSION_%s", version.ToAltSvc()),
|
||||
url,
|
||||
}
|
||||
utils.DefaultLogger.Infof("Running chrome: %s '%s'", getChromePath(), strings.Join(args, "' '"))
|
||||
command := exec.Command(path, args...)
|
||||
session, err := gexec.Start(command, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
blockUntilDone()
|
||||
}
|
||||
|
||||
func waitForDone() {
|
||||
Eventually(func() bool { return doneCalled.Get() }, 60).Should(BeTrue())
|
||||
}
|
||||
|
||||
func waitForNUploaded(expected int) func() {
|
||||
return func() {
|
||||
Eventually(func() int32 {
|
||||
return atomic.LoadInt32(&nFilesUploaded)
|
||||
}, 60).Should(BeEquivalentTo(expected))
|
||||
}
|
||||
}
|
||||
|
||||
const commonJS = `
|
||||
var buf = new ArrayBuffer(LENGTH);
|
||||
var prng = new Uint8Array(buf);
|
||||
var seed = 1;
|
||||
for (var i = 0; i < LENGTH; i++) {
|
||||
// https://en.wikipedia.org/wiki/Lehmer_random_number_generator
|
||||
seed = seed * 48271 % 2147483647;
|
||||
prng[i] = seed;
|
||||
}
|
||||
`
|
||||
|
||||
const uploadHTML = `
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
console.log("Running DL test...");
|
||||
|
||||
` + commonJS + `
|
||||
for (var i = 0; i < NUM; i++) {
|
||||
var req = new XMLHttpRequest();
|
||||
req.open("POST", "/uploadhandler?len=" + LENGTH, true);
|
||||
req.send(buf);
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
const downloadHTML = `
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
console.log("Running DL test...");
|
||||
` + commonJS + `
|
||||
|
||||
function verify(data) {
|
||||
if (data.length !== LENGTH) return false;
|
||||
for (var i = 0; i < LENGTH; i++) {
|
||||
if (data[i] !== prng[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
var nOK = 0;
|
||||
for (var i = 0; i < NUM; i++) {
|
||||
let req = new XMLHttpRequest();
|
||||
req.responseType = "arraybuffer";
|
||||
req.open("POST", "/prdata?len=" + LENGTH, true);
|
||||
req.onreadystatechange = function () {
|
||||
if (req.readyState === XMLHttpRequest.DONE && req.status === 200) {
|
||||
if (verify(new Uint8Array(req.response))) {
|
||||
nOK++;
|
||||
if (nOK === NUM) {
|
||||
console.log("Done :)");
|
||||
var reqDone = new XMLHttpRequest();
|
||||
reqDone.open("GET", "/done");
|
||||
reqDone.send();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
req.send();
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`
|
|
@ -0,0 +1,76 @@
|
|||
package chrome_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
)
|
||||
|
||||
var _ = Describe("Chrome tests", func() {
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
// TODO: activate Chrome integration tests with gQUIC 44
|
||||
if version == protocol.Version44 {
|
||||
continue
|
||||
}
|
||||
|
||||
Context(fmt.Sprintf("with version %s", version), func() {
|
||||
JustBeforeEach(func() {
|
||||
testserver.StartQuicServer([]protocol.VersionNumber{version})
|
||||
})
|
||||
|
||||
It("downloads a small file", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLen),
|
||||
waitForDone,
|
||||
)
|
||||
})
|
||||
|
||||
It("downloads a large file", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLongLen),
|
||||
waitForDone,
|
||||
)
|
||||
})
|
||||
|
||||
It("loads a large number of files", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
"https://quic.clemente.io/downloadtest?num=4&len=100",
|
||||
waitForDone,
|
||||
)
|
||||
})
|
||||
|
||||
It("uploads a small file", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLen),
|
||||
waitForNUploaded(1),
|
||||
)
|
||||
})
|
||||
|
||||
It("uploads a large file", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLongLen),
|
||||
waitForNUploaded(1),
|
||||
)
|
||||
})
|
||||
|
||||
It("uploads many small files", func() {
|
||||
num := protocol.DefaultMaxIncomingStreams + 20
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=%d&len=%d", num, dataLen),
|
||||
waitForNUploaded(num),
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,137 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
. "github.com/onsi/gomega/gbytes"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
|
||||
|
||||
var _ = Describe("Drop tests", func() {
|
||||
var proxy *quicproxy.QuicProxy
|
||||
|
||||
startProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
|
||||
var err error
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
|
||||
RemoteAddr: "localhost:" + testserver.Port(),
|
||||
DropPacket: dropCallback,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
downloadFile := func(version protocol.VersionNumber) {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
}
|
||||
|
||||
downloadHello := func(version protocol.VersionNumber) {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
||||
"https://quic.clemente.io/hello",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(session.Out).To(Say(":status 200"))
|
||||
Expect(session.Out).To(Say("body: Hello, World!\n"))
|
||||
}
|
||||
|
||||
deterministicDropper := func(p, interval, dropInARow uint64) bool {
|
||||
return (p % interval) < dropInARow
|
||||
}
|
||||
|
||||
stochasticDropper := func(freq int) bool {
|
||||
return mrand.Int63n(int64(freq)) == 0
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
Context("during the crypto handshake", func() {
|
||||
for _, d := range directions {
|
||||
direction := d
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p == 1 && d.Is(direction)
|
||||
}, version)
|
||||
downloadHello(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p == 2 && d.Is(direction)
|
||||
}, version)
|
||||
downloadHello(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return d.Is(direction) && stochasticDropper(5)
|
||||
}, version)
|
||||
downloadHello(version)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
Context("after the crypto handshake", func() {
|
||||
for _, d := range directions {
|
||||
direction := d
|
||||
|
||||
It(fmt.Sprintf("downloads a file when every 5th packet is dropped in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p >= 10 && d.Is(direction) && deterministicDropper(p, 5, 1)
|
||||
}, version)
|
||||
downloadFile(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("downloads a file when 1/5th of all packet are dropped randomly in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p >= 10 && d.Is(direction) && stochasticDropper(5)
|
||||
}, version)
|
||||
downloadFile(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("downloads a file when 10 packets every 100 packet are dropped in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p >= 10 && d.Is(direction) && deterministicDropper(p, 100, 10)
|
||||
}, version)
|
||||
downloadFile(version)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,45 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
clientPath string
|
||||
serverPath string
|
||||
)
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "GQuic Tests Suite")
|
||||
}
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
rand.Seed(GinkgoRandomSeed())
|
||||
})
|
||||
|
||||
var _ = JustBeforeEach(func() {
|
||||
testserver.StartQuicServer(nil)
|
||||
})
|
||||
|
||||
var _ = AfterEach(testserver.StopQuicServer)
|
||||
|
||||
func init() {
|
||||
_, thisfile, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
panic("Failed to get current path")
|
||||
}
|
||||
clientPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/client-%s-debug", runtime.GOOS))
|
||||
serverPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/server-%s-debug", runtime.GOOS))
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
. "github.com/onsi/gomega/gbytes"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
var _ = Describe("Integration tests", func() {
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
It("gets a simple file", func() {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+testserver.Port(),
|
||||
"https://quic.clemente.io/hello",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 5).Should(Exit(0))
|
||||
Expect(session.Out).To(Say(":status 200"))
|
||||
Expect(session.Out).To(Say("body: Hello, World!\n"))
|
||||
})
|
||||
|
||||
It("posts and reads a body", func() {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+testserver.Port(),
|
||||
"--body=foo",
|
||||
"https://quic.clemente.io/echo",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 5).Should(Exit(0))
|
||||
Expect(session.Out).To(Say(":status 200"))
|
||||
Expect(session.Out).To(Say("body: foo\n"))
|
||||
})
|
||||
|
||||
It("gets a file", func() {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+testserver.Port(),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 10).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("gets many copies of a file in parallel", func() {
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer GinkgoRecover()
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+testserver.Port(),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,95 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
// get a random duration between min and max
|
||||
func getRandomDuration(min, max time.Duration) time.Duration {
|
||||
return min + time.Duration(rand.Int63n(int64(max-min)))
|
||||
}
|
||||
|
||||
var _ = Describe("Random Duration Generator", func() {
|
||||
It("gets a random RTT", func() {
|
||||
var min time.Duration = time.Hour
|
||||
var max time.Duration
|
||||
|
||||
var sum time.Duration
|
||||
rep := 10000
|
||||
for i := 0; i < rep; i++ {
|
||||
val := getRandomDuration(100*time.Millisecond, 500*time.Millisecond)
|
||||
sum += val
|
||||
if val < min {
|
||||
min = val
|
||||
}
|
||||
if val > max {
|
||||
max = val
|
||||
}
|
||||
}
|
||||
avg := sum / time.Duration(rep)
|
||||
Expect(avg).To(BeNumerically("~", 300*time.Millisecond, 5*time.Millisecond))
|
||||
Expect(min).To(BeNumerically(">=", 100*time.Millisecond))
|
||||
Expect(min).To(BeNumerically("<", 105*time.Millisecond))
|
||||
Expect(max).To(BeNumerically(">", 495*time.Millisecond))
|
||||
Expect(max).To(BeNumerically("<=", 500*time.Millisecond))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Random RTT", func() {
|
||||
var proxy *quicproxy.QuicProxy
|
||||
|
||||
runRTTTest := func(minRtt, maxRtt time.Duration, version protocol.VersionNumber) {
|
||||
var err error
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
|
||||
RemoteAddr: "localhost:" + testserver.Port(),
|
||||
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
|
||||
return getRandomDuration(minRtt, maxRtt)
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
err := proxy.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(time.Millisecond)
|
||||
})
|
||||
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
It("gets a file a random RTT between 10ms and 30ms", func() {
|
||||
runRTTTest(10*time.Millisecond, 30*time.Millisecond, version)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,66 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
var _ = Describe("non-zero RTT", func() {
|
||||
var proxy *quicproxy.QuicProxy
|
||||
|
||||
runRTTTest := func(rtt time.Duration, version protocol.VersionNumber) {
|
||||
var err error
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
|
||||
RemoteAddr: "localhost:" + testserver.Port(),
|
||||
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
err := proxy.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(time.Millisecond)
|
||||
})
|
||||
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
roundTrips := [...]int{10, 50, 100, 200}
|
||||
for _, rtt := range roundTrips {
|
||||
It(fmt.Sprintf("gets a 500kB file with %dms RTT", rtt), func() {
|
||||
runRTTTest(time.Duration(rtt)*time.Millisecond, version)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,218 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
mrand "math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
var _ = Describe("Server tests", func() {
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
var (
|
||||
serverPort string
|
||||
tmpDir string
|
||||
session *Session
|
||||
client *http.Client
|
||||
)
|
||||
|
||||
generateCA := func() (*rsa.PrivateKey, *x509.Certificate) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
templateRoot := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
IsCA: true,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, templateRoot, templateRoot, &key.PublicKey, key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return key, cert
|
||||
}
|
||||
|
||||
// prepare the file such that it can be by the quic_server
|
||||
// some HTTP headers neeed to be prepended, see https://www.chromium.org/quic/playing-with-quic
|
||||
createDownloadFile := func(filename string, data []byte) {
|
||||
dataDir := filepath.Join(tmpDir, "quic.clemente.io")
|
||||
err := os.Mkdir(dataDir, 0777)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
f, err := os.Create(filepath.Join(dataDir, filename))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer f.Close()
|
||||
_, err = f.Write([]byte("HTTP/1.1 200 OK\n"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write([]byte("Content-Type: text/html\n"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write([]byte("X-Original-Url: https://quic.clemente.io:" + serverPort + "/" + filename + "\n"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write([]byte("Content-Length: " + strconv.Itoa(len(data)) + "\n\n"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
// download files must be create *before* the quic_server is started
|
||||
// the quic_server reads its data dir on startup, and only serves those files that were already present then
|
||||
startServer := func(version protocol.VersionNumber) {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
command := exec.Command(
|
||||
serverPath,
|
||||
"--quic_response_cache_dir="+filepath.Join(tmpDir, "quic.clemente.io"),
|
||||
"--key_file="+filepath.Join(tmpDir, "key.pkcs8"),
|
||||
"--certificate_file="+filepath.Join(tmpDir, "cert.pem"),
|
||||
"--quic-version="+strconv.Itoa(int(version)),
|
||||
"--port="+serverPort,
|
||||
)
|
||||
session, err = Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
|
||||
stopServer := func() {
|
||||
session.Kill()
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
serverPort = strconv.Itoa(20000 + int(mrand.Int31n(10000)))
|
||||
|
||||
var err error
|
||||
tmpDir, err = ioutil.TempDir("", "quic-server-certs")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// generate an RSA key pair for the server
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// save the private key in PKCS8 format to disk (required by quic_server)
|
||||
pkcs8key, err := asn1.Marshal(struct { // copied from the x509 package
|
||||
Version int
|
||||
Algo pkix.AlgorithmIdentifier
|
||||
PrivateKey []byte
|
||||
}{
|
||||
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
|
||||
Algo: pkix.AlgorithmIdentifier{
|
||||
Algorithm: asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1},
|
||||
Parameters: asn1.RawValue{Tag: 5},
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
f, err := os.Create(filepath.Join(tmpDir, "key.pkcs8"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write(pkcs8key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
f.Close()
|
||||
|
||||
// generate a Certificate Authority
|
||||
// this CA is used to sign the server's key
|
||||
// it is set as a valid CA in the QUIC client
|
||||
rootKey, CACert := generateCA()
|
||||
// generate the server certificate
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-30 * time.Minute),
|
||||
NotAfter: time.Now().Add(30 * time.Minute),
|
||||
Subject: pkix.Name{CommonName: "quic.clemente.io"},
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, CACert, &key.PublicKey, rootKey)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// save the certificate to disk
|
||||
certOut, err := os.Create(filepath.Join(tmpDir, "cert.pem"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
certOut.Close()
|
||||
|
||||
// prepare the h2quic.client
|
||||
certPool := x509.NewCertPool()
|
||||
certPool.AddCert(CACert)
|
||||
client = &http.Client{
|
||||
Transport: &h2quic.RoundTripper{
|
||||
TLSClientConfig: &tls.Config{RootCAs: certPool},
|
||||
QuicConfig: &quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(tmpDir).ToNot(BeEmpty())
|
||||
err := os.RemoveAll(tmpDir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tmpDir = ""
|
||||
})
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
It("downloads a hello", func() {
|
||||
data := []byte("Hello world!\n")
|
||||
createDownloadFile("hello", data)
|
||||
|
||||
startServer(version)
|
||||
defer stopServer()
|
||||
|
||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/hello")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal(data))
|
||||
})
|
||||
|
||||
It("downloads a small file", func() {
|
||||
createDownloadFile("file.dat", testserver.PRData)
|
||||
|
||||
startServer(version)
|
||||
defer stopServer()
|
||||
|
||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal(testserver.PRData))
|
||||
})
|
||||
|
||||
It("downloads a large file", func() {
|
||||
createDownloadFile("file.dat", testserver.PRDataLong)
|
||||
|
||||
startServer(version)
|
||||
defer stopServer()
|
||||
|
||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 20*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal(testserver.PRDataLong))
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,97 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
)
|
||||
|
||||
var _ = Describe("Client tests", func() {
|
||||
var client *http.Client
|
||||
|
||||
// also run some tests with the TLS handshake
|
||||
versions := append(protocol.SupportedVersions, protocol.VersionTLS)
|
||||
|
||||
BeforeEach(func() {
|
||||
err := os.Setenv("HOSTALIASES", "quic.clemente.io 127.0.0.1")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
addr, err := net.ResolveUDPAddr("udp4", "quic.clemente.io:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
if addr.String() != "127.0.0.1:0" {
|
||||
Fail("quic.clemente.io does not resolve to 127.0.0.1. Consider adding it to /etc/hosts.")
|
||||
}
|
||||
testserver.StartQuicServer(versions)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
testserver.StopQuicServer()
|
||||
})
|
||||
|
||||
for _, v := range versions {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
BeforeEach(func() {
|
||||
client = &http.Client{
|
||||
Transport: &h2quic.RoundTripper{
|
||||
QuicConfig: &quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
It("downloads a hello", func() {
|
||||
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/hello")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal("Hello, World!\n"))
|
||||
})
|
||||
|
||||
It("downloads a small file", func() {
|
||||
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/prdata")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal(testserver.PRData))
|
||||
})
|
||||
|
||||
It("downloads a large file", func() {
|
||||
resp, err := client.Get("https://quic.clemente.io:" + testserver.Port() + "/prdatalong")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal(testserver.PRDataLong))
|
||||
})
|
||||
|
||||
It("uploads a file", func() {
|
||||
resp, err := client.Post(
|
||||
"https://quic.clemente.io:"+testserver.Port()+"/echo",
|
||||
"text/plain",
|
||||
bytes.NewReader(testserver.PRData),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(bytes.Equal(body, testserver.PRData)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,101 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Connection ID lengths tests", func() {
|
||||
randomConnIDLen := func() int {
|
||||
return 4 + int(rand.Int31n(15))
|
||||
}
|
||||
|
||||
runServer := func(conf *quic.Config) quic.Listener {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength)))
|
||||
ln, err := quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
sess, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer str.Close()
|
||||
_, err = str.Write(testserver.PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
}
|
||||
}()
|
||||
return ln
|
||||
}
|
||||
|
||||
runClient := func(addr net.Addr, conf *quic.Config) {
|
||||
GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", conf.ConnectionIDLength)))
|
||||
cl, err := quic.DialAddr(
|
||||
fmt.Sprintf("quic.clemente.io:%d", addr.(*net.UDPAddr).Port),
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
conf,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer cl.Close()
|
||||
str, err := cl.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(testserver.PRData))
|
||||
}
|
||||
|
||||
Context("IETF QUIC", func() {
|
||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
||||
serverConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
clientConf := &quic.Config{
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), clientConf)
|
||||
})
|
||||
|
||||
It("downloads a file when both client and server use a random connection ID length", func() {
|
||||
serverConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
clientConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), clientConf)
|
||||
})
|
||||
})
|
||||
|
||||
Context("gQUIC", func() {
|
||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
||||
ln := runServer(&quic.Config{})
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), &quic.Config{RequestConnectionIDOmission: true})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,189 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
)
|
||||
|
||||
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
|
||||
|
||||
type applicationProtocol struct {
|
||||
name string
|
||||
run func(protocol.VersionNumber)
|
||||
}
|
||||
|
||||
var _ = Describe("Handshake drop tests", func() {
|
||||
var (
|
||||
proxy *quicproxy.QuicProxy
|
||||
ln quic.Listener
|
||||
)
|
||||
|
||||
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
|
||||
var err error
|
||||
ln, err = quic.ListenAddr(
|
||||
"localhost:0",
|
||||
testdata.GetTLSConfig(),
|
||||
&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DropPacket: dropCallback,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
stochasticDropper := func(freq int) bool {
|
||||
return mrand.Int63n(int64(freq)) == 0
|
||||
}
|
||||
|
||||
clientSpeaksFirst := &applicationProtocol{
|
||||
name: "client speaks first",
|
||||
run: func(version protocol.VersionNumber) {
|
||||
serverSessionChan := make(chan quic.Session)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer sess.Close()
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(b)).To(Equal("foobar"))
|
||||
serverSessionChan <- sess
|
||||
}()
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
|
||||
nil,
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var serverSession quic.Session
|
||||
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
|
||||
sess.Close()
|
||||
serverSession.Close()
|
||||
},
|
||||
}
|
||||
|
||||
serverSpeaksFirst := &applicationProtocol{
|
||||
name: "server speaks first",
|
||||
run: func(version protocol.VersionNumber) {
|
||||
serverSessionChan := make(chan quic.Session)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverSessionChan <- sess
|
||||
}()
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
|
||||
nil,
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b := make([]byte, 6)
|
||||
_, err = gbytes.TimeoutReader(str, 10*time.Second).Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(b)).To(Equal("foobar"))
|
||||
|
||||
var serverSession quic.Session
|
||||
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
|
||||
sess.Close()
|
||||
serverSession.Close()
|
||||
},
|
||||
}
|
||||
|
||||
nobodySpeaks := &applicationProtocol{
|
||||
name: "nobody speaks",
|
||||
run: func(version protocol.VersionNumber) {
|
||||
serverSessionChan := make(chan quic.Session)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverSessionChan <- sess
|
||||
}()
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
|
||||
nil,
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var serverSession quic.Session
|
||||
Eventually(serverSessionChan, 10*time.Second).Should(Receive(&serverSession))
|
||||
// both server and client accepted a session. Close now.
|
||||
sess.Close()
|
||||
serverSession.Close()
|
||||
},
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
for _, d := range directions {
|
||||
direction := d
|
||||
|
||||
for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} {
|
||||
app := a
|
||||
|
||||
Context(app.name, func() {
|
||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
|
||||
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p == 1 && d.Is(direction)
|
||||
}, version)
|
||||
app.run(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
|
||||
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p == 2 && d.Is(direction)
|
||||
}, version)
|
||||
app.run(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
|
||||
startListenerAndProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return d.Is(direction) && stochasticDropper(5)
|
||||
}, version)
|
||||
app.run(version)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,213 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Handshake RTT tests", func() {
|
||||
var (
|
||||
proxy *quicproxy.QuicProxy
|
||||
server quic.Listener
|
||||
serverConfig *quic.Config
|
||||
testStartedAt time.Time
|
||||
acceptStopped chan struct{}
|
||||
)
|
||||
|
||||
rtt := 400 * time.Millisecond
|
||||
|
||||
BeforeEach(func() {
|
||||
acceptStopped = make(chan struct{})
|
||||
serverConfig = &quic.Config{}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
Expect(server.Close()).To(Succeed())
|
||||
<-acceptStopped
|
||||
})
|
||||
|
||||
runServerAndProxy := func() {
|
||||
var err error
|
||||
// start the server
|
||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// start the proxy
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", protocol.VersionWhatever, &quicproxy.Opts{
|
||||
RemoteAddr: server.Addr().String(),
|
||||
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration { return rtt / 2 },
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
testStartedAt = time.Now()
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(acceptStopped)
|
||||
for {
|
||||
_, err := server.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
expectDurationInRTTs := func(num int) {
|
||||
testDuration := time.Since(testStartedAt)
|
||||
rtts := float32(testDuration) / float32(rtt)
|
||||
Expect(rtts).To(SatisfyAll(
|
||||
BeNumerically(">=", num),
|
||||
BeNumerically("<", num+1),
|
||||
))
|
||||
}
|
||||
|
||||
It("fails when there's no matching version, after 1 RTT", func() {
|
||||
if len(protocol.SupportedVersions) == 1 {
|
||||
Skip("Test requires at least 2 supported versions.")
|
||||
}
|
||||
serverConfig.Versions = protocol.SupportedVersions[:1]
|
||||
runServerAndProxy()
|
||||
clientConfig := &quic.Config{
|
||||
Versions: protocol.SupportedVersions[1:2],
|
||||
}
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
|
||||
expectDurationInRTTs(1)
|
||||
})
|
||||
|
||||
Context("gQUIC", func() {
|
||||
// 1 RTT for verifying the source address
|
||||
// 1 RTT to become secure
|
||||
// 1 RTT to become forward-secure
|
||||
It("is forward-secure after 3 RTTs", func() {
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(3)
|
||||
})
|
||||
|
||||
It("does version negotiation in 1 RTT, IETF QUIC => gQUIC", func() {
|
||||
clientConfig := &quic.Config{
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS, protocol.SupportedVersions[0]},
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(4)
|
||||
})
|
||||
|
||||
It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return true
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
})
|
||||
|
||||
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return false
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
|
||||
})
|
||||
|
||||
It("doesn't complete the handshake when the handshake timeout is too short", func() {
|
||||
serverConfig.HandshakeTimeout = 2 * rtt
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
||||
// 2 RTTs during the timeout
|
||||
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
|
||||
expectDurationInRTTs(3)
|
||||
})
|
||||
})
|
||||
|
||||
Context("IETF QUIC", func() {
|
||||
var clientConfig *quic.Config
|
||||
var clientTLSConfig *tls.Config
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
|
||||
clientConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
clientTLSConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "quic.clemente.io",
|
||||
}
|
||||
})
|
||||
|
||||
// 1 RTT for verifying the source address
|
||||
// 1 RTT for the TLS handshake
|
||||
It("is forward-secure after 2 RTTs", func() {
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
})
|
||||
|
||||
It("does version negotiation in 1 RTT, gQUIC => IETF QUIC", func() {
|
||||
clientConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[0], protocol.VersionTLS}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(3)
|
||||
})
|
||||
|
||||
It("is forward-secure after 1 RTTs when the server doesn't require a Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return true
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(1)
|
||||
})
|
||||
|
||||
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return false
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.CryptoTooManyRejects))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,128 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type versioner interface {
|
||||
GetVersion() protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ = Describe("Handshake tests", func() {
|
||||
var (
|
||||
server quic.Listener
|
||||
serverConfig *quic.Config
|
||||
acceptStopped chan struct{}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
server = nil
|
||||
acceptStopped = make(chan struct{})
|
||||
serverConfig = &quic.Config{}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
if server != nil {
|
||||
server.Close()
|
||||
<-acceptStopped
|
||||
}
|
||||
})
|
||||
|
||||
runServer := func() {
|
||||
var err error
|
||||
// start the server
|
||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer close(acceptStopped)
|
||||
for {
|
||||
_, err := server.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
Context("Version Negotiation", func() {
|
||||
var supportedVersions []protocol.VersionNumber
|
||||
|
||||
BeforeEach(func() {
|
||||
supportedVersions = protocol.SupportedVersions
|
||||
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
protocol.SupportedVersions = supportedVersions
|
||||
})
|
||||
|
||||
It("when the server supports more versions than the client", func() {
|
||||
// the server doesn't support the highest supported version, which is the first one the client will try
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[1], 7, 8, 9}
|
||||
runServer()
|
||||
sess, err := quic.DialAddr(server.Addr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[1]))
|
||||
})
|
||||
|
||||
It("when the client supports more versions than the server supports", func() {
|
||||
// the server doesn't support the highest supported version, which is the first one the client will try
|
||||
// but it supports a bunch of versions that the client doesn't speak
|
||||
serverConfig.Versions = supportedVersions
|
||||
runServer()
|
||||
conf := &quic.Config{
|
||||
Versions: []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[1], 10},
|
||||
}
|
||||
sess, err := quic.DialAddr(server.Addr().String(), &tls.Config{InsecureSkipVerify: true}, conf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[1]))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Certifiate validation", func() {
|
||||
for _, v := range []protocol.VersionNumber{protocol.Version39, protocol.VersionTLS} {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("using %s", version), func() {
|
||||
var clientConfig *quic.Config
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig.Versions = []protocol.VersionNumber{version}
|
||||
clientConfig = &quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
}
|
||||
})
|
||||
|
||||
It("accepts the certificate", func() {
|
||||
runServer()
|
||||
_, err := quic.DialAddr(fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port), nil, clientConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors if the server name doesn't match", func() {
|
||||
runServer()
|
||||
_, err := quic.DialAddr(fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), nil, clientConfig)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("uses the ServerName in the tls.Config", func() {
|
||||
runServer()
|
||||
conf := &tls.Config{ServerName: "quic.clemente.io"}
|
||||
_, err := quic.DialAddr(fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), conf, clientConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
|
@ -0,0 +1,232 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Multiplexing", func() {
|
||||
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
|
||||
version := v
|
||||
|
||||
// gQUIC 44 uses 0 byte connection IDs for packets sent to the client
|
||||
// It's not possible to do demultiplexing.
|
||||
if v == protocol.Version44 {
|
||||
continue
|
||||
}
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
runServer := func(ln quic.Listener) {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
for {
|
||||
sess, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer str.Close()
|
||||
_, err = str.Write(testserver.PRDataLong)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
dial := func(conn net.PacketConn, addr net.Addr) {
|
||||
sess, err := quic.Dial(
|
||||
conn,
|
||||
addr,
|
||||
fmt.Sprintf("quic.clemente.io:%d", addr.(*net.UDPAddr).Port),
|
||||
nil,
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(testserver.PRDataLong))
|
||||
}
|
||||
|
||||
Context("multiplexing clients on the same conn", func() {
|
||||
getListener := func() quic.Listener {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
testdata.GetTLSConfig(),
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return ln
|
||||
}
|
||||
|
||||
It("multiplexes connections to the same server", func() {
|
||||
server := getListener()
|
||||
runServer(server)
|
||||
defer server.Close()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
if testlog.Debug() {
|
||||
timeout = time.Minute
|
||||
}
|
||||
Eventually(done1, timeout).Should(BeClosed())
|
||||
Eventually(done2, timeout).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("multiplexes connections to different servers", func() {
|
||||
server1 := getListener()
|
||||
runServer(server1)
|
||||
defer server1.Close()
|
||||
server2 := getListener()
|
||||
runServer(server2)
|
||||
defer server2.Close()
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server1.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server2.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
if testlog.Debug() {
|
||||
timeout = time.Minute
|
||||
}
|
||||
Eventually(done1, timeout).Should(BeClosed())
|
||||
Eventually(done2, timeout).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("multiplexing server and client on the same conn", func() {
|
||||
It("connects to itself", func() {
|
||||
if version != protocol.VersionTLS {
|
||||
Skip("Connecting to itself only works with IETF QUIC.")
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn.Close()
|
||||
|
||||
server, err := quic.Listen(
|
||||
conn,
|
||||
testdata.GetTLSConfig(),
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runServer(server)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn, server.Addr())
|
||||
close(done)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
if testlog.Debug() {
|
||||
timeout = time.Minute
|
||||
}
|
||||
Eventually(done, timeout).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("runs a server and client on the same conn", func() {
|
||||
if os.Getenv("CI") == "true" {
|
||||
Skip("This test is flaky on CIs, see see https://github.com/golang/go/issues/17677.")
|
||||
}
|
||||
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn1, err := net.ListenUDP("udp", addr1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn1.Close()
|
||||
|
||||
addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn2, err := net.ListenUDP("udp", addr2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer conn2.Close()
|
||||
|
||||
server1, err := quic.Listen(
|
||||
conn1,
|
||||
testdata.GetTLSConfig(),
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runServer(server1)
|
||||
defer server1.Close()
|
||||
|
||||
server2, err := quic.Listen(
|
||||
conn2,
|
||||
testdata.GetTLSConfig(),
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runServer(server2)
|
||||
defer server2.Close()
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn2, server1.Addr())
|
||||
close(done1)
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
dial(conn1, server2.Addr())
|
||||
close(done2)
|
||||
}()
|
||||
timeout := 30 * time.Second
|
||||
if testlog.Debug() {
|
||||
timeout = time.Minute
|
||||
}
|
||||
Eventually(done1, timeout).Should(BeClosed())
|
||||
Eventually(done2, timeout).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,83 @@
|
|||
package self
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("non-zero RTT", func() {
|
||||
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
roundTrips := [...]time.Duration{
|
||||
10 * time.Millisecond,
|
||||
50 * time.Millisecond,
|
||||
100 * time.Millisecond,
|
||||
200 * time.Millisecond,
|
||||
}
|
||||
|
||||
for _, r := range roundTrips {
|
||||
rtt := r
|
||||
|
||||
It(fmt.Sprintf("downloads a message with %s RTT", rtt), func() {
|
||||
ln, err := quic.ListenAddr(
|
||||
"localhost:0",
|
||||
testdata.GetTLSConfig(),
|
||||
&quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := ln.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.OpenStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(testserver.PRData)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str.Close()
|
||||
close(done)
|
||||
}()
|
||||
serverPort := ln.Addr().(*net.UDPAddr).Port
|
||||
proxy, err := quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
|
||||
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
|
||||
DelayPacket: func(d quicproxy.Direction, p uint64) time.Duration {
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer proxy.Close()
|
||||
|
||||
sess, err := quic.DialAddr(
|
||||
fmt.Sprintf("quic.clemente.io:%d", proxy.LocalPort()),
|
||||
nil,
|
||||
&quic.Config{Versions: []protocol.VersionNumber{version}},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(testserver.PRData))
|
||||
sess.Close()
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,20 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
||||
)
|
||||
|
||||
func TestSelf(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Self integration tests")
|
||||
}
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
rand.Seed(GinkgoRandomSeed())
|
||||
})
|
|
@ -0,0 +1,152 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Bidirectional streams", func() {
|
||||
const numStreams = 300
|
||||
|
||||
var (
|
||||
server quic.Listener
|
||||
serverAddr string
|
||||
qconf *quic.Config
|
||||
)
|
||||
|
||||
for _, v := range []protocol.VersionNumber{protocol.VersionTLS} {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("with QUIC %s", version), func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
qconf = &quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
MaxIncomingStreams: 0,
|
||||
}
|
||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), qconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr = fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
server.Close()
|
||||
})
|
||||
|
||||
runSendingPeer := func(sess quic.Session) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.OpenStreamSync()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data := testserver.GeneratePRData(25 * i)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
dataRead, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(dataRead).To(Equal(data))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
runReceivingPeer := func(sess quic.Session) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.AcceptStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
// shouldn't use io.Copy here
|
||||
// we should read from the stream as early as possible, to free flow control credit
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = str.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
|
||||
var sess quic.Session
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
sess, err = server.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(sess)
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(client)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(sess)
|
||||
sess.Close()
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(client)
|
||||
Eventually(client.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("client and server opening %d each and sending data to the peer", numStreams), func() {
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runReceivingPeer(sess)
|
||||
close(done)
|
||||
}()
|
||||
runSendingPeer(sess)
|
||||
<-done
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runSendingPeer(client)
|
||||
close(done2)
|
||||
}()
|
||||
runReceivingPeer(client)
|
||||
<-done1
|
||||
<-done2
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,132 @@
|
|||
package self_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Unidirectional Streams", func() {
|
||||
const numStreams = 500
|
||||
|
||||
var (
|
||||
server quic.Listener
|
||||
serverAddr string
|
||||
qconf *quic.Config
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
qconf = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), qconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAddr = fmt.Sprintf("quic.clemente.io:%d", server.Addr().(*net.UDPAddr).Port)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
server.Close()
|
||||
})
|
||||
|
||||
dataForStream := func(id protocol.StreamID) []byte {
|
||||
return testserver.GeneratePRData(10 * int(id))
|
||||
}
|
||||
|
||||
runSendingPeer := func(sess quic.Session) {
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.OpenUniStreamSync()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
_, err := str.Write(dataForStream(str.StreamID()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(str.Close()).To(Succeed())
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
runReceivingPeer := func(sess quic.Session) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
str, err := sess.AcceptUniStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer wg.Done()
|
||||
data, err := ioutil.ReadAll(str)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(data).To(Equal(dataForStream(str.StreamID())))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() {
|
||||
var sess quic.Session
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
sess, err = server.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(sess)
|
||||
sess.Close()
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(client)
|
||||
<-client.Context().Done()
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runSendingPeer(sess)
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
runReceivingPeer(client)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("client and server opening %d streams each and sending data to the peer", numStreams), func() {
|
||||
done1 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess, err := server.Accept()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runReceivingPeer(sess)
|
||||
close(done)
|
||||
}()
|
||||
runSendingPeer(sess)
|
||||
<-done
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
client, err := quic.DialAddr(serverAddr, nil, qconf)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done2 := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
runSendingPeer(client)
|
||||
close(done2)
|
||||
}()
|
||||
runReceivingPeer(client)
|
||||
<-done1
|
||||
<-done2
|
||||
})
|
||||
})
|
|
@ -0,0 +1,270 @@
|
|||
package quicproxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Connection is a UDP connection
|
||||
type connection struct {
|
||||
ClientAddr *net.UDPAddr // Address of the client
|
||||
ServerConn *net.UDPConn // UDP connection to server
|
||||
|
||||
incomingPacketCounter uint64
|
||||
outgoingPacketCounter uint64
|
||||
}
|
||||
|
||||
// Direction is the direction a packet is sent.
|
||||
type Direction int
|
||||
|
||||
const (
|
||||
// DirectionIncoming is the direction from the client to the server.
|
||||
DirectionIncoming Direction = iota
|
||||
// DirectionOutgoing is the direction from the server to the client.
|
||||
DirectionOutgoing
|
||||
// DirectionBoth is both incoming and outgoing
|
||||
DirectionBoth
|
||||
)
|
||||
|
||||
func (d Direction) String() string {
|
||||
switch d {
|
||||
case DirectionIncoming:
|
||||
return "incoming"
|
||||
case DirectionOutgoing:
|
||||
return "outgoing"
|
||||
case DirectionBoth:
|
||||
return "both"
|
||||
default:
|
||||
panic("unknown direction")
|
||||
}
|
||||
}
|
||||
|
||||
// Is says if one direction matches another direction.
|
||||
// For example, incoming matches both incoming and both, but not outgoing.
|
||||
func (d Direction) Is(dir Direction) bool {
|
||||
if d == DirectionBoth || dir == DirectionBoth {
|
||||
return true
|
||||
}
|
||||
return d == dir
|
||||
}
|
||||
|
||||
// DropCallback is a callback that determines which packet gets dropped.
|
||||
type DropCallback func(dir Direction, packetCount uint64) bool
|
||||
|
||||
// NoDropper doesn't drop packets.
|
||||
var NoDropper DropCallback = func(Direction, uint64) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// DelayCallback is a callback that determines how much delay to apply to a packet.
|
||||
type DelayCallback func(dir Direction, packetCount uint64) time.Duration
|
||||
|
||||
// NoDelay doesn't apply a delay.
|
||||
var NoDelay DelayCallback = func(Direction, uint64) time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Opts are proxy options.
|
||||
type Opts struct {
|
||||
// The address this proxy proxies packets to.
|
||||
RemoteAddr string
|
||||
// DropPacket determines whether a packet gets dropped.
|
||||
DropPacket DropCallback
|
||||
// DelayPacket determines how long a packet gets delayed. This allows
|
||||
// simulating a connection with non-zero RTTs.
|
||||
// Note that the RTT is the sum of the delay for the incoming and the outgoing packet.
|
||||
DelayPacket DelayCallback
|
||||
}
|
||||
|
||||
// QuicProxy is a QUIC proxy that can drop and delay packets.
|
||||
type QuicProxy struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
||||
conn *net.UDPConn
|
||||
serverAddr *net.UDPAddr
|
||||
|
||||
dropPacket DropCallback
|
||||
delayPacket DelayCallback
|
||||
|
||||
// Mapping from client addresses (as host:port) to connection
|
||||
clientDict map[string]*connection
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// NewQuicProxy creates a new UDP proxy
|
||||
func NewQuicProxy(local string, version protocol.VersionNumber, opts *Opts) (*QuicProxy, error) {
|
||||
if opts == nil {
|
||||
opts = &Opts{}
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr("udp", local)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
packetDropper := NoDropper
|
||||
if opts.DropPacket != nil {
|
||||
packetDropper = opts.DropPacket
|
||||
}
|
||||
|
||||
packetDelayer := NoDelay
|
||||
if opts.DelayPacket != nil {
|
||||
packetDelayer = opts.DelayPacket
|
||||
}
|
||||
|
||||
p := QuicProxy{
|
||||
clientDict: make(map[string]*connection),
|
||||
conn: conn,
|
||||
serverAddr: raddr,
|
||||
dropPacket: packetDropper,
|
||||
delayPacket: packetDelayer,
|
||||
version: version,
|
||||
logger: utils.DefaultLogger.WithPrefix("proxy"),
|
||||
}
|
||||
|
||||
p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
|
||||
go p.runProxy()
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Close stops the UDP Proxy
|
||||
func (p *QuicProxy) Close() error {
|
||||
p.mutex.Lock()
|
||||
defer p.mutex.Unlock()
|
||||
for _, c := range p.clientDict {
|
||||
if err := c.ServerConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return p.conn.Close()
|
||||
}
|
||||
|
||||
// LocalAddr is the address the proxy is listening on.
|
||||
func (p *QuicProxy) LocalAddr() net.Addr {
|
||||
return p.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// LocalPort is the UDP port number the proxy is listening on.
|
||||
func (p *QuicProxy) LocalPort() int {
|
||||
return p.conn.LocalAddr().(*net.UDPAddr).Port
|
||||
}
|
||||
|
||||
func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
|
||||
srvudp, err := net.DialUDP("udp", nil, p.serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &connection{
|
||||
ClientAddr: cliAddr,
|
||||
ServerConn: srvudp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runProxy listens on the proxy address and handles incoming packets.
|
||||
func (p *QuicProxy) runProxy() error {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, cliaddr, err := p.conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
||||
saddr := cliaddr.String()
|
||||
p.mutex.Lock()
|
||||
conn, ok := p.clientDict[saddr]
|
||||
|
||||
if !ok {
|
||||
conn, err = p.newConnection(cliaddr)
|
||||
if err != nil {
|
||||
p.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
p.clientDict[saddr] = conn
|
||||
go p.runConnection(conn)
|
||||
}
|
||||
p.mutex.Unlock()
|
||||
|
||||
packetCount := atomic.AddUint64(&conn.incomingPacketCounter, 1)
|
||||
|
||||
if p.dropPacket(DirectionIncoming, packetCount) {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("dropping incoming packet %d (%d bytes)", packetCount, n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the packet to the server
|
||||
delay := p.delayPacket(DirectionIncoming, packetCount)
|
||||
if delay != 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying incoming packet %d (%d bytes) to %s by %s", packetCount, n, conn.ServerConn.RemoteAddr(), delay)
|
||||
}
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
_, _ = conn.ServerConn.Write(raw)
|
||||
})
|
||||
} else {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding incoming packet %d (%d bytes) to %s", packetCount, n, conn.ServerConn.RemoteAddr())
|
||||
}
|
||||
if _, err := conn.ServerConn.Write(raw); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runConnection handles packets from server to a single client
|
||||
func (p *QuicProxy) runConnection(conn *connection) error {
|
||||
for {
|
||||
buffer := make([]byte, protocol.MaxReceivePacketSize)
|
||||
n, err := conn.ServerConn.Read(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
raw := buffer[0:n]
|
||||
|
||||
packetCount := atomic.AddUint64(&conn.outgoingPacketCounter, 1)
|
||||
|
||||
if p.dropPacket(DirectionOutgoing, packetCount) {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("dropping outgoing packet %d (%d bytes)", packetCount, n)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
delay := p.delayPacket(DirectionOutgoing, packetCount)
|
||||
if delay != 0 {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("delaying outgoing packet %d (%d bytes) to %s by %s", packetCount, n, conn.ClientAddr, delay)
|
||||
}
|
||||
time.AfterFunc(delay, func() {
|
||||
// TODO: handle error
|
||||
_, _ = p.conn.WriteToUDP(raw, conn.ClientAddr)
|
||||
})
|
||||
} else {
|
||||
if p.logger.Debug() {
|
||||
p.logger.Debugf("forwarding outgoing packet %d (%d bytes) to %s", packetCount, n, conn.ClientAddr)
|
||||
}
|
||||
if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
13
vendor/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy_suite_test.go
vendored
Normal file
13
vendor/lucas-clemente/quic-go/integrationtests/tools/proxy/proxy_suite_test.go
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
package quicproxy
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQuicGo(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "QUIC Proxy")
|
||||
}
|
|
@ -0,0 +1,394 @@
|
|||
package quicproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"runtime/pprof"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type packetData []byte
|
||||
|
||||
var _ = Describe("QUIC Proxy", func() {
|
||||
makePacket := func(p protocol.PacketNumber, payload []byte) []byte {
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
PacketNumber: p,
|
||||
PacketNumberLen: protocol.PacketNumberLen6,
|
||||
DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37},
|
||||
}
|
||||
hdr.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
raw := b.Bytes()
|
||||
raw = append(raw, payload...)
|
||||
return raw
|
||||
}
|
||||
|
||||
Context("Proxy setup and teardown", func() {
|
||||
It("sets up the UDPProxy", func() {
|
||||
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(proxy.clientDict).To(HaveLen(0))
|
||||
|
||||
// check that the proxy port is in use
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(proxy.LocalPort()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = net.ListenUDP("udp", addr)
|
||||
Expect(err).To(MatchError(fmt.Sprintf("listen udp 127.0.0.1:%d: bind: address already in use", proxy.LocalPort())))
|
||||
Expect(proxy.Close()).To(Succeed()) // stopping is tested in the next test
|
||||
})
|
||||
|
||||
It("stops the UDPProxy", func() {
|
||||
isProxyRunning := func() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "proxy.(*QuicProxy).runProxy")
|
||||
}
|
||||
|
||||
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
port := proxy.LocalPort()
|
||||
Expect(isProxyRunning()).To(BeTrue())
|
||||
err = proxy.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// check that the proxy port is not in use anymore
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:"+strconv.Itoa(port))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// sometimes it takes a while for the OS to free the port
|
||||
Eventually(func() error {
|
||||
ln, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ln.Close()
|
||||
return nil
|
||||
}).ShouldNot(HaveOccurred())
|
||||
Eventually(isProxyRunning).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("stops listening for proxied connections", func() {
|
||||
isConnRunning := func() bool {
|
||||
var b bytes.Buffer
|
||||
pprof.Lookup("goroutine").WriteTo(&b, 1)
|
||||
return strings.Contains(b.String(), "proxy.(*QuicProxy).runConnection")
|
||||
}
|
||||
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConn, err := net.ListenUDP("udp", serverAddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer serverConn.Close()
|
||||
|
||||
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, &Opts{RemoteAddr: serverConn.LocalAddr().String()})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(isConnRunning()).To(BeFalse())
|
||||
|
||||
// check that the proxy port is not in use anymore
|
||||
conn, err := net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = conn.Write(makePacket(1, []byte("foobar")))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(isConnRunning).Should(BeTrue())
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
Eventually(isConnRunning).Should(BeFalse())
|
||||
})
|
||||
|
||||
It("has the correct LocalAddr and LocalPort", func() {
|
||||
proxy, err := NewQuicProxy("localhost:0", protocol.VersionWhatever, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(proxy.LocalAddr().String()).To(Equal("127.0.0.1:" + strconv.Itoa(proxy.LocalPort())))
|
||||
Expect(proxy.LocalPort()).ToNot(BeZero())
|
||||
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Proxy tests", func() {
|
||||
var (
|
||||
serverConn *net.UDPConn
|
||||
serverNumPacketsSent int32
|
||||
serverReceivedPackets chan packetData
|
||||
clientConn *net.UDPConn
|
||||
proxy *QuicProxy
|
||||
)
|
||||
|
||||
startProxy := func(opts *Opts) {
|
||||
var err error
|
||||
proxy, err = NewQuicProxy("localhost:0", protocol.VersionWhatever, opts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientConn, err = net.DialUDP("udp", nil, proxy.LocalAddr().(*net.UDPAddr))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
// getClientDict returns a copy of the clientDict map
|
||||
getClientDict := func() map[string]*connection {
|
||||
d := make(map[string]*connection)
|
||||
proxy.mutex.Lock()
|
||||
defer proxy.mutex.Unlock()
|
||||
for k, v := range proxy.clientDict {
|
||||
d[k] = v
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
serverReceivedPackets = make(chan packetData, 100)
|
||||
atomic.StoreInt32(&serverNumPacketsSent, 0)
|
||||
|
||||
// setup a dump UDP server
|
||||
// in production this would be a QUIC server
|
||||
raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverConn, err = net.ListenUDP("udp", raddr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, protocol.MaxReceivePacketSize)
|
||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
||||
n, addr, err2 := serverConn.ReadFromUDP(buf)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
data := buf[0:n]
|
||||
serverReceivedPackets <- packetData(data)
|
||||
// echo the packet
|
||||
serverConn.WriteToUDP(data, addr)
|
||||
atomic.AddInt32(&serverNumPacketsSent, 1)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
err := proxy.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = serverConn.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = clientConn.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
})
|
||||
|
||||
Context("no packet drop", func() {
|
||||
It("relays packets from the client to the server", func() {
|
||||
startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
|
||||
// send the first packet
|
||||
_, err := clientConn.Write(makePacket(1, []byte("foobar")))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Eventually(getClientDict).Should(HaveLen(1))
|
||||
var conn *connection
|
||||
for _, conn = range getClientDict() {
|
||||
Eventually(func() uint64 { return atomic.LoadUint64(&conn.incomingPacketCounter) }).Should(Equal(uint64(1)))
|
||||
}
|
||||
|
||||
// send the second packet
|
||||
_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Eventually(serverReceivedPackets).Should(HaveLen(2))
|
||||
Expect(getClientDict()).To(HaveLen(1))
|
||||
Expect(string(<-serverReceivedPackets)).To(ContainSubstring("foobar"))
|
||||
Expect(string(<-serverReceivedPackets)).To(ContainSubstring("decafbad"))
|
||||
})
|
||||
|
||||
It("relays packets from the server to the client", func() {
|
||||
startProxy(&Opts{RemoteAddr: serverConn.LocalAddr().String()})
|
||||
// send the first packet
|
||||
_, err := clientConn.Write(makePacket(1, []byte("foobar")))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Eventually(getClientDict).Should(HaveLen(1))
|
||||
var key string
|
||||
var conn *connection
|
||||
for key, conn = range getClientDict() {
|
||||
Eventually(func() uint64 { return atomic.LoadUint64(&conn.outgoingPacketCounter) }).Should(Equal(uint64(1)))
|
||||
}
|
||||
|
||||
// send the second packet
|
||||
_, err = clientConn.Write(makePacket(2, []byte("decafbad")))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(getClientDict()).To(HaveLen(1))
|
||||
Eventually(func() uint64 {
|
||||
conn := getClientDict()[key]
|
||||
return atomic.LoadUint64(&conn.outgoingPacketCounter)
|
||||
}).Should(BeEquivalentTo(2))
|
||||
|
||||
clientReceivedPackets := make(chan packetData, 2)
|
||||
// receive the packets echoed by the server on client side
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, protocol.MaxReceivePacketSize)
|
||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
||||
n, _, err2 := clientConn.ReadFromUDP(buf)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
data := buf[0:n]
|
||||
clientReceivedPackets <- packetData(data)
|
||||
}
|
||||
}()
|
||||
|
||||
Eventually(serverReceivedPackets).Should(HaveLen(2))
|
||||
Expect(atomic.LoadInt32(&serverNumPacketsSent)).To(BeEquivalentTo(2))
|
||||
Eventually(clientReceivedPackets).Should(HaveLen(2))
|
||||
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("foobar"))
|
||||
Expect(string(<-clientReceivedPackets)).To(ContainSubstring("decafbad"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Drop Callbacks", func() {
|
||||
It("drops incoming packets", func() {
|
||||
opts := &Opts{
|
||||
RemoteAddr: serverConn.LocalAddr().String(),
|
||||
DropPacket: func(d Direction, p uint64) bool {
|
||||
return d == DirectionIncoming && p%2 == 0
|
||||
},
|
||||
}
|
||||
startProxy(opts)
|
||||
|
||||
for i := 1; i <= 6; i++ {
|
||||
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Eventually(serverReceivedPackets).Should(HaveLen(3))
|
||||
Consistently(serverReceivedPackets).Should(HaveLen(3))
|
||||
})
|
||||
|
||||
It("drops outgoing packets", func() {
|
||||
const numPackets = 6
|
||||
opts := &Opts{
|
||||
RemoteAddr: serverConn.LocalAddr().String(),
|
||||
DropPacket: func(d Direction, p uint64) bool {
|
||||
return d == DirectionOutgoing && p%2 == 0
|
||||
},
|
||||
}
|
||||
startProxy(opts)
|
||||
|
||||
clientReceivedPackets := make(chan packetData, numPackets)
|
||||
// receive the packets echoed by the server on client side
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, protocol.MaxReceivePacketSize)
|
||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
||||
n, _, err2 := clientConn.ReadFromUDP(buf)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
data := buf[0:n]
|
||||
clientReceivedPackets <- packetData(data)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 1; i <= numPackets; i++ {
|
||||
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
Eventually(clientReceivedPackets).Should(HaveLen(numPackets / 2))
|
||||
Consistently(clientReceivedPackets).Should(HaveLen(numPackets / 2))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Delay Callback", func() {
|
||||
expectDelay := func(startTime time.Time, rtt time.Duration, numRTTs int) {
|
||||
expectedReceiveTime := startTime.Add(time.Duration(numRTTs) * rtt)
|
||||
Expect(time.Now()).To(SatisfyAll(
|
||||
BeTemporally(">=", expectedReceiveTime),
|
||||
BeTemporally("<", expectedReceiveTime.Add(rtt/2)),
|
||||
))
|
||||
}
|
||||
|
||||
It("delays incoming packets", func() {
|
||||
delay := 300 * time.Millisecond
|
||||
opts := &Opts{
|
||||
RemoteAddr: serverConn.LocalAddr().String(),
|
||||
// delay packet 1 by 200 ms
|
||||
// delay packet 2 by 400 ms
|
||||
// ...
|
||||
DelayPacket: func(d Direction, p uint64) time.Duration {
|
||||
if d == DirectionOutgoing {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(p) * delay
|
||||
},
|
||||
}
|
||||
startProxy(opts)
|
||||
|
||||
// send 3 packets
|
||||
start := time.Now()
|
||||
for i := 1; i <= 3; i++ {
|
||||
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Eventually(serverReceivedPackets).Should(HaveLen(1))
|
||||
expectDelay(start, delay, 1)
|
||||
Eventually(serverReceivedPackets).Should(HaveLen(2))
|
||||
expectDelay(start, delay, 2)
|
||||
Eventually(serverReceivedPackets).Should(HaveLen(3))
|
||||
expectDelay(start, delay, 3)
|
||||
})
|
||||
|
||||
It("delays outgoing packets", func() {
|
||||
const numPackets = 3
|
||||
delay := 300 * time.Millisecond
|
||||
opts := &Opts{
|
||||
RemoteAddr: serverConn.LocalAddr().String(),
|
||||
// delay packet 1 by 200 ms
|
||||
// delay packet 2 by 400 ms
|
||||
// ...
|
||||
DelayPacket: func(d Direction, p uint64) time.Duration {
|
||||
if d == DirectionIncoming {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(p) * delay
|
||||
},
|
||||
}
|
||||
startProxy(opts)
|
||||
|
||||
clientReceivedPackets := make(chan packetData, numPackets)
|
||||
// receive the packets echoed by the server on client side
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, protocol.MaxReceivePacketSize)
|
||||
// the ReadFromUDP will error as soon as the UDP conn is closed
|
||||
n, _, err2 := clientConn.ReadFromUDP(buf)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
data := buf[0:n]
|
||||
clientReceivedPackets <- packetData(data)
|
||||
}
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
for i := 1; i <= numPackets; i++ {
|
||||
_, err := clientConn.Write(makePacket(protocol.PacketNumber(i), []byte("foobar"+strconv.Itoa(i))))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
// the packets should have arrived immediately at the server
|
||||
Eventually(serverReceivedPackets).Should(HaveLen(3))
|
||||
expectDelay(start, delay, 0)
|
||||
Eventually(clientReceivedPackets).Should(HaveLen(1))
|
||||
expectDelay(start, delay, 1)
|
||||
Eventually(clientReceivedPackets).Should(HaveLen(2))
|
||||
expectDelay(start, delay, 2)
|
||||
Eventually(clientReceivedPackets).Should(HaveLen(3))
|
||||
expectDelay(start, delay, 3)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,46 @@
|
|||
package testlog
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var (
|
||||
logFileName string // the log file set in the ginkgo flags
|
||||
logFile *os.File
|
||||
)
|
||||
|
||||
// read the logfile command line flag
|
||||
// to set call ginkgo -- -logfile=log.txt
|
||||
func init() {
|
||||
flag.StringVar(&logFileName, "logfile", "", "log file")
|
||||
}
|
||||
|
||||
var _ = BeforeEach(func() {
|
||||
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
|
||||
|
||||
if len(logFileName) > 0 {
|
||||
var err error
|
||||
logFile, err = os.Create(logFileName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
log.SetOutput(logFile)
|
||||
utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
|
||||
}
|
||||
})
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
if len(logFileName) > 0 {
|
||||
_ = logFile.Close()
|
||||
}
|
||||
})
|
||||
|
||||
// Debug says if this test is being logged
|
||||
func Debug() bool {
|
||||
return len(logFileName) > 0
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
package testserver
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
dataLen = 500 * 1024 // 500 KB
|
||||
dataLenLong = 50 * 1024 * 1024 // 50 MB
|
||||
)
|
||||
|
||||
var (
|
||||
// PRData contains dataLen bytes of pseudo-random data.
|
||||
PRData = GeneratePRData(dataLen)
|
||||
// PRDataLong contains dataLenLong bytes of pseudo-random data.
|
||||
PRDataLong = GeneratePRData(dataLenLong)
|
||||
|
||||
server *h2quic.Server
|
||||
stoppedServing chan struct{}
|
||||
port string
|
||||
)
|
||||
|
||||
func init() {
|
||||
http.HandleFunc("/prdata", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
sl := r.URL.Query().Get("len")
|
||||
if sl != "" {
|
||||
var err error
|
||||
l, err := strconv.Atoi(sl)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = w.Write(GeneratePRData(l))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
} else {
|
||||
_, err := w.Write(PRData)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
http.HandleFunc("/prdatalong", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
_, err := w.Write(PRDataLong)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
_, err := io.WriteString(w, "Hello, World!\n")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
http.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = w.Write(body)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
}
|
||||
|
||||
// See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
|
||||
func GeneratePRData(l int) []byte {
|
||||
res := make([]byte, l)
|
||||
seed := uint64(1)
|
||||
for i := 0; i < l; i++ {
|
||||
seed = seed * 48271 % 2147483647
|
||||
res[i] = byte(seed)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// StartQuicServer starts a h2quic.Server.
|
||||
// versions is a slice of supported QUIC versions. It may be nil, then all supported versions are used.
|
||||
func StartQuicServer(versions []protocol.VersionNumber) {
|
||||
server = &h2quic.Server{
|
||||
Server: &http.Server{
|
||||
TLSConfig: testdata.GetTLSConfig(),
|
||||
},
|
||||
QuicConfig: &quic.Config{
|
||||
Versions: versions,
|
||||
},
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
port = strconv.Itoa(conn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
stoppedServing = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.Serve(conn)
|
||||
close(stoppedServing)
|
||||
}()
|
||||
}
|
||||
|
||||
// StopQuicServer stops the h2quic.Server.
|
||||
func StopQuicServer() {
|
||||
Expect(server.Close()).NotTo(HaveOccurred())
|
||||
Eventually(stoppedServing).Should(BeClosed())
|
||||
}
|
||||
|
||||
// Port returns the UDP port of the QUIC server.
|
||||
func Port() string {
|
||||
return port
|
||||
}
|
|
@ -0,0 +1,223 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// The StreamID is the ID of a QUIC stream.
|
||||
type StreamID = protocol.StreamID
|
||||
|
||||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
const (
|
||||
// VersionGQUIC39 is gQUIC version 39.
|
||||
VersionGQUIC39 = protocol.Version39
|
||||
// VersionGQUIC43 is gQUIC version 43.
|
||||
VersionGQUIC43 = protocol.Version43
|
||||
// VersionGQUIC44 is gQUIC version 44.
|
||||
VersionGQUIC44 = protocol.Version44
|
||||
// VersionMilestone0_10_0 uses TLS
|
||||
VersionMilestone0_10_0 = protocol.VersionMilestone0_10_0
|
||||
)
|
||||
|
||||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
|
||||
// ConnectionState records basic details about the QUIC connection.
|
||||
type ConnectionState = handshake.ConnectionState
|
||||
|
||||
// An ErrorCode is an application-defined error code.
|
||||
type ErrorCode = protocol.ApplicationErrorCode
|
||||
|
||||
// Stream is the interface implemented by QUIC streams
|
||||
type Stream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Read reads data from the stream.
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Reader
|
||||
// Write writes data to the stream.
|
||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
io.Writer
|
||||
// Close closes the write-direction of the stream.
|
||||
// Future calls to Write are not permitted after calling Close.
|
||||
// It must not be called concurrently with Write.
|
||||
// It must not be called after calling CancelWrite.
|
||||
io.Closer
|
||||
// CancelWrite aborts sending on this stream.
|
||||
// It must not be called after Close.
|
||||
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||
// Write will unblock immediately, and future calls to Write will fail.
|
||||
CancelWrite(ErrorCode) error
|
||||
// CancelRead aborts receiving on this stream.
|
||||
// It will ask the peer to stop transmitting stream data.
|
||||
// Read will unblock immediately, and future Read calls will fail.
|
||||
CancelRead(ErrorCode) error
|
||||
// The context is canceled as soon as the write-side of the stream is closed.
|
||||
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Context() context.Context
|
||||
// SetReadDeadline sets the deadline for future Read calls and
|
||||
// any currently-blocked Read call.
|
||||
// A zero value for t means Read will not time out.
|
||||
SetReadDeadline(t time.Time) error
|
||||
// SetWriteDeadline sets the deadline for future Write calls
|
||||
// and any currently-blocked Write call.
|
||||
// Even if write times out, it may return n > 0, indicating that
|
||||
// some of the data was successfully written.
|
||||
// A zero value for t means Write will not time out.
|
||||
SetWriteDeadline(t time.Time) error
|
||||
// SetDeadline sets the read and write deadlines associated
|
||||
// with the connection. It is equivalent to calling both
|
||||
// SetReadDeadline and SetWriteDeadline.
|
||||
SetDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A ReceiveStream is a unidirectional Receive Stream.
|
||||
type ReceiveStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Read
|
||||
io.Reader
|
||||
// see Stream.CancelRead
|
||||
CancelRead(ErrorCode) error
|
||||
// see Stream.SetReadDealine
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A SendStream is a unidirectional Send Stream.
|
||||
type SendStream interface {
|
||||
// see Stream.StreamID
|
||||
StreamID() StreamID
|
||||
// see Stream.Write
|
||||
io.Writer
|
||||
// see Stream.Close
|
||||
io.Closer
|
||||
// see Stream.CancelWrite
|
||||
CancelWrite(ErrorCode) error
|
||||
// see Stream.Context
|
||||
Context() context.Context
|
||||
// see Stream.SetWriteDeadline
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// StreamError is returned by Read and Write when the peer cancels the stream.
|
||||
type StreamError interface {
|
||||
error
|
||||
Canceled() bool
|
||||
ErrorCode() ErrorCode
|
||||
}
|
||||
|
||||
// A Session is a QUIC connection between two peers.
|
||||
type Session interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
AcceptStream() (Stream, error)
|
||||
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||
AcceptUniStream() (ReceiveStream, error)
|
||||
// OpenStream opens a new bidirectional QUIC stream.
|
||||
// It returns a special error when the peer's concurrent stream limit is reached.
|
||||
// There is no signaling to the peer about new streams:
|
||||
// The peer can only accept the stream after data has been sent on the stream.
|
||||
// TODO(#1152): Enable testing for the special error
|
||||
OpenStream() (Stream, error)
|
||||
// OpenStreamSync opens a new bidirectional QUIC stream.
|
||||
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
|
||||
OpenStreamSync() (Stream, error)
|
||||
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
|
||||
// It returns a special error when the peer's concurrent stream limit is reached.
|
||||
// TODO(#1152): Enable testing for the special error
|
||||
OpenUniStream() (SendStream, error)
|
||||
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
|
||||
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
|
||||
OpenUniStreamSync() (SendStream, error)
|
||||
// LocalAddr returns the local address.
|
||||
LocalAddr() net.Addr
|
||||
// RemoteAddr returns the address of the peer.
|
||||
RemoteAddr() net.Addr
|
||||
// Close the connection.
|
||||
io.Closer
|
||||
// Close the connection with an error.
|
||||
// The error must not be nil.
|
||||
CloseWithError(ErrorCode, error) error
|
||||
// The context is cancelled when the session is closed.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Context() context.Context
|
||||
// ConnectionState returns basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
ConnectionState() ConnectionState
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
type Config struct {
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
// Warning: This API should not be considered stable and will change soon.
|
||||
Versions []VersionNumber
|
||||
// Ask the server to omit the connection ID sent in the Public Header.
|
||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||
// Currently only valid for the client.
|
||||
RequestConnectionIDOmission bool
|
||||
// The length of the connection ID in bytes. Only valid for IETF QUIC.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If not set, the interpretation depends on where the Config is used:
|
||||
// If used for dialing an address, a 0 byte connection ID will be used.
|
||||
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
|
||||
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
|
||||
ConnectionIDLength int
|
||||
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 10 seconds.
|
||||
HandshakeTimeout time.Duration
|
||||
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
|
||||
// This value only applies after the handshake has completed.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
IdleTimeout time.Duration
|
||||
// AcceptCookie determines if a Cookie is accepted.
|
||||
// It is called with cookie = nil if the client didn't send an Cookie.
|
||||
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
|
||||
// This option is only valid for the server.
|
||||
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
|
||||
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
|
||||
MaxReceiveStreamFlowControlWindow uint64
|
||||
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
|
||||
MaxReceiveConnectionFlowControlWindow uint64
|
||||
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingStreams int
|
||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||
// This value doesn't have any effect in Google QUIC.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingUniStreams int
|
||||
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
|
||||
KeepAlive bool
|
||||
}
|
||||
|
||||
// A Listener for incoming QUIC connections
|
||||
type Listener interface {
|
||||
// Close the server, sending CONNECTION_CLOSE frames to each peer.
|
||||
Close() error
|
||||
// Addr returns the local network addr that the server is listening on.
|
||||
Addr() net.Addr
|
||||
// Accept returns new sessions. It should be called in a loop.
|
||||
Accept() (Session, error)
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCrypto(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "AckHandler Suite")
|
||||
}
|
||||
|
||||
var mockCtrl *gomock.Controller
|
||||
|
||||
var _ = BeforeEach(func() {
|
||||
mockCtrl = gomock.NewController(GinkgoT())
|
||||
})
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
mockCtrl.Finish()
|
||||
})
|
|
@ -0,0 +1,3 @@
|
|||
package ackhandler
|
||||
|
||||
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet
|
|
@ -0,0 +1,47 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet)
|
||||
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
|
||||
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
|
||||
SetHandshakeComplete()
|
||||
|
||||
// The SendMode determines if and what kind of packets can be sent.
|
||||
SendMode() SendMode
|
||||
// TimeUntilSend is the time when the next packet should be sent.
|
||||
// It is used for pacing packets.
|
||||
TimeUntilSend() time.Time
|
||||
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
|
||||
// It always returns a number greater or equal than 1.
|
||||
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
|
||||
// Note that the number of packets is only calculated based on the pacing algorithm.
|
||||
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||
ShouldSendNumPackets() int
|
||||
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() *Packet
|
||||
DequeueProbePacket() (*Packet, error)
|
||||
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
OnAlarm() error
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
|
||||
IgnoreBelow(protocol.PacketNumber)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame() *wire.AckFrame
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// A Packet is a packet
|
||||
type Packet struct {
|
||||
PacketNumber protocol.PacketNumber
|
||||
PacketType protocol.PacketType
|
||||
Frames []wire.Frame
|
||||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
SendTime time.Time
|
||||
|
||||
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
|
||||
|
||||
// There are two reasons why a packet cannot be retransmitted:
|
||||
// * it was already retransmitted
|
||||
// * this packet is a retransmission, and we already received an ACK for the original packet
|
||||
canBeRetransmitted bool
|
||||
includedInBytesInFlight bool
|
||||
retransmittedAs []protocol.PacketNumber
|
||||
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
|
||||
retransmissionOf protocol.PacketNumber
|
||||
}
|
|
@ -0,0 +1,217 @@
|
|||
// This file was automatically generated by genny.
|
||||
// Any changes will be lost if this file is regenerated.
|
||||
// see https://github.com/cheekybits/genny
|
||||
|
||||
package ackhandler
|
||||
|
||||
// Linked list implementation from the Go standard library.
|
||||
|
||||
// PacketElement is an element of a linked list.
|
||||
type PacketElement struct {
|
||||
// Next and previous pointers in the doubly-linked list of elements.
|
||||
// To simplify the implementation, internally a list l is implemented
|
||||
// as a ring, such that &l.root is both the next element of the last
|
||||
// list element (l.Back()) and the previous element of the first list
|
||||
// element (l.Front()).
|
||||
next, prev *PacketElement
|
||||
|
||||
// The list to which this element belongs.
|
||||
list *PacketList
|
||||
|
||||
// The value stored with this element.
|
||||
Value Packet
|
||||
}
|
||||
|
||||
// Next returns the next list element or nil.
|
||||
func (e *PacketElement) Next() *PacketElement {
|
||||
if p := e.next; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prev returns the previous list element or nil.
|
||||
func (e *PacketElement) Prev() *PacketElement {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PacketList is a linked list of Packets.
|
||||
type PacketList struct {
|
||||
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list length excluding (this) sentinel element
|
||||
}
|
||||
|
||||
// Init initializes or clears list l.
|
||||
func (l *PacketList) Init() *PacketList {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
// NewPacketList returns an initialized list.
|
||||
func NewPacketList() *PacketList { return new(PacketList).Init() }
|
||||
|
||||
// Len returns the number of elements of list l.
|
||||
// The complexity is O(1).
|
||||
func (l *PacketList) Len() int { return l.len }
|
||||
|
||||
// Front returns the first element of list l or nil if the list is empty.
|
||||
func (l *PacketList) Front() *PacketElement {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.next
|
||||
}
|
||||
|
||||
// Back returns the last element of list l or nil if the list is empty.
|
||||
func (l *PacketList) Back() *PacketElement {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// lazyInit lazily initializes a zero List value.
|
||||
func (l *PacketList) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
}
|
||||
}
|
||||
|
||||
// insert inserts e after at, increments l.len, and returns e.
|
||||
func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
|
||||
n := at.next
|
||||
at.next = e
|
||||
e.prev = at
|
||||
e.next = n
|
||||
n.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
|
||||
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
|
||||
return l.insert(&PacketElement{Value: v}, at)
|
||||
}
|
||||
|
||||
// remove removes e from its list, decrements l.len, and returns e.
|
||||
func (l *PacketList) remove(e *PacketElement) *PacketElement {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil // avoid memory leaks
|
||||
e.prev = nil // avoid memory leaks
|
||||
e.list = nil
|
||||
l.len--
|
||||
return e
|
||||
}
|
||||
|
||||
// Remove removes e from l if e is an element of list l.
|
||||
// It returns the element value e.Value.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) Remove(e *PacketElement) Packet {
|
||||
if e.list == l {
|
||||
// if e.list == l, l must have been initialized when e was inserted
|
||||
// in l or l == nil (e is a zero Element) and l.remove will crash
|
||||
l.remove(e)
|
||||
}
|
||||
return e.Value
|
||||
}
|
||||
|
||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
|
||||
func (l *PacketList) PushFront(v Packet) *PacketElement {
|
||||
l.lazyInit()
|
||||
return l.insertValue(v, &l.root)
|
||||
}
|
||||
|
||||
// PushBack inserts a new element e with value v at the back of list l and returns e.
|
||||
func (l *PacketList) PushBack(v Packet) *PacketElement {
|
||||
l.lazyInit()
|
||||
return l.insertValue(v, l.root.prev)
|
||||
}
|
||||
|
||||
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark.prev)
|
||||
}
|
||||
|
||||
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) MoveToFront(e *PacketElement) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), &l.root)
|
||||
}
|
||||
|
||||
// MoveToBack moves element e to the back of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *PacketList) MoveToBack(e *PacketElement) {
|
||||
if e.list != l || l.root.prev == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), l.root.prev)
|
||||
}
|
||||
|
||||
// MoveBefore moves element e to its new position before mark.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
}
|
||||
l.insert(l.remove(e), mark.prev)
|
||||
}
|
||||
|
||||
// MoveAfter moves element e to its new position after mark.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
}
|
||||
l.insert(l.remove(e), mark)
|
||||
}
|
||||
|
||||
// PushBackList inserts a copy of an other list at the back of list l.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *PacketList) PushBackList(other *PacketList) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
||||
l.insertValue(e.Value, l.root.prev)
|
||||
}
|
||||
}
|
||||
|
||||
// PushFrontList inserts a copy of an other list at the front of list l.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *PacketList) PushFrontList(other *PacketList) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
||||
l.insertValue(e.Value, &l.root)
|
||||
}
|
||||
}
|
215
vendor/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
vendored
Normal file
215
vendor/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go
vendored
Normal file
|
@ -0,0 +1,215 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type receivedPacketHandler struct {
|
||||
largestObserved protocol.PacketNumber
|
||||
ignoreBelow protocol.PacketNumber
|
||||
largestObservedReceivedTime time.Time
|
||||
|
||||
packetHistory *receivedPacketHistory
|
||||
|
||||
ackSendDelay time.Duration
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
packetsReceivedSinceLastAck int
|
||||
retransmittablePacketsReceivedSinceLastAck int
|
||||
ackQueued bool
|
||||
ackAlarm time.Time
|
||||
lastAck *wire.AckFrame
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
const (
|
||||
// maximum delay that can be applied to an ACK for a retransmittable packet
|
||||
ackSendDelay = 25 * time.Millisecond
|
||||
// initial maximum number of retransmittable packets received before sending an ack.
|
||||
initialRetransmittablePacketsBeforeAck = 2
|
||||
// number of retransmittable that an ACK is sent for
|
||||
retransmittablePacketsBeforeAck = 10
|
||||
// 1/5 RTT delay when doing ack decimation
|
||||
ackDecimationDelay = 1.0 / 4
|
||||
// 1/8 RTT delay when doing ack decimation
|
||||
shortAckDecimationDelay = 1.0 / 8
|
||||
// Minimum number of packets received before ack decimation is enabled.
|
||||
// This intends to avoid the beginning of slow start, when CWNDs may be
|
||||
// rapidly increasing.
|
||||
minReceivedBeforeAckDecimation = 100
|
||||
// Maximum number of packets to ack immediately after a missing packet for
|
||||
// fast retransmission to kick in at the sender. This limit is created to
|
||||
// reduce the number of acks sent that have no benefit for fast retransmission.
|
||||
// Set to the number of nacks needed for fast retransmit plus one for protection
|
||||
// against an ack loss
|
||||
maxPacketsAfterNewMissing = 4
|
||||
)
|
||||
|
||||
// NewReceivedPacketHandler creates a new receivedPacketHandler
|
||||
func NewReceivedPacketHandler(
|
||||
rttStats *congestion.RTTStats,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
ackSendDelay: ackSendDelay,
|
||||
rttStats: rttStats,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
|
||||
if packetNumber < h.ignoreBelow {
|
||||
return nil
|
||||
}
|
||||
|
||||
isMissing := h.isMissing(packetNumber)
|
||||
if packetNumber > h.largestObserved {
|
||||
h.largestObserved = packetNumber
|
||||
h.largestObservedReceivedTime = rcvTime
|
||||
}
|
||||
|
||||
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IgnoreBelow sets a lower limit for acking packets.
|
||||
// Packets with packet numbers smaller than p will not be acked.
|
||||
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
|
||||
if p <= h.ignoreBelow {
|
||||
return
|
||||
}
|
||||
h.ignoreBelow = p
|
||||
h.packetHistory.DeleteBelow(p)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
|
||||
}
|
||||
}
|
||||
|
||||
// isMissing says if a packet was reported missing in the last ACK.
|
||||
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
|
||||
if h.lastAck == nil || p < h.ignoreBelow {
|
||||
return false
|
||||
}
|
||||
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
|
||||
if h.lastAck == nil {
|
||||
return false
|
||||
}
|
||||
highestRange := h.packetHistory.GetHighestAckRange()
|
||||
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
|
||||
}
|
||||
|
||||
// maybeQueueAck queues an ACK, if necessary.
|
||||
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
|
||||
// in ACK_DECIMATION_WITH_REORDERING mode.
|
||||
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
|
||||
h.packetsReceivedSinceLastAck++
|
||||
|
||||
// always ack the first packet
|
||||
if h.lastAck == nil {
|
||||
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||
h.ackQueued = true
|
||||
return
|
||||
}
|
||||
|
||||
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||
// missing packets we reported in the previous ack, send an ACK immediately.
|
||||
if wasMissing {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
|
||||
}
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
if !h.ackQueued && shouldInstigateAck {
|
||||
h.retransmittablePacketsReceivedSinceLastAck++
|
||||
|
||||
if packetNumber > minReceivedBeforeAckDecimation {
|
||||
// ack up to 10 packets at once
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
|
||||
h.ackQueued = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
|
||||
}
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
|
||||
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
|
||||
h.ackAlarm = rcvTime.Add(ackDelay)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// send an ACK every 2 retransmittable packets
|
||||
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
|
||||
}
|
||||
h.ackQueued = true
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
|
||||
}
|
||||
h.ackAlarm = rcvTime.Add(ackSendDelay)
|
||||
}
|
||||
}
|
||||
// If there are new missing packets to report, set a short timer to send an ACK.
|
||||
if h.hasNewMissingPackets() {
|
||||
// wait the minimum of 1/8 min RTT and the existing ack time
|
||||
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
|
||||
ackTime := rcvTime.Add(ackDelay)
|
||||
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
|
||||
h.ackAlarm = ackTime
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
// cancel the ack alarm
|
||||
h.ackAlarm = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
|
||||
now := time.Now()
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
|
||||
return nil
|
||||
}
|
||||
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
|
||||
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||
}
|
||||
|
||||
ack := &wire.AckFrame{
|
||||
AckRanges: h.packetHistory.GetAckRanges(),
|
||||
DelayTime: now.Sub(h.largestObservedReceivedTime),
|
||||
}
|
||||
|
||||
h.lastAck = ack
|
||||
h.ackAlarm = time.Time{}
|
||||
h.ackQueued = false
|
||||
h.packetsReceivedSinceLastAck = 0
|
||||
h.retransmittablePacketsReceivedSinceLastAck = 0
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
382
vendor/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler_test.go
vendored
Normal file
382
vendor/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler_test.go
vendored
Normal file
|
@ -0,0 +1,382 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("receivedPacketHandler", func() {
|
||||
var (
|
||||
handler *receivedPacketHandler
|
||||
rttStats *congestion.RTTStats
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
rttStats = &congestion.RTTStats{}
|
||||
handler = NewReceivedPacketHandler(rttStats, utils.DefaultLogger, protocol.VersionWhatever).(*receivedPacketHandler)
|
||||
})
|
||||
|
||||
Context("accepting packets", func() {
|
||||
It("handles a packet that arrives late", func() {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(1), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(protocol.PacketNumber(3), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(protocol.PacketNumber(2), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("saves the time when each packet arrived", func() {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(3), time.Now(), true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
|
||||
})
|
||||
|
||||
It("updates the largestObserved and the largestObservedReceivedTime", func() {
|
||||
now := time.Now()
|
||||
handler.largestObserved = 3
|
||||
handler.largestObservedReceivedTime = now.Add(-1 * time.Second)
|
||||
err := handler.ReceivedPacket(5, now, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
|
||||
Expect(handler.largestObservedReceivedTime).To(Equal(now))
|
||||
})
|
||||
|
||||
It("doesn't update the largestObserved and the largestObservedReceivedTime for a belated packet", func() {
|
||||
now := time.Now()
|
||||
timestamp := now.Add(-1 * time.Second)
|
||||
handler.largestObserved = 5
|
||||
handler.largestObservedReceivedTime = timestamp
|
||||
err := handler.ReceivedPacket(4, now, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
|
||||
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
|
||||
})
|
||||
|
||||
It("passes on errors from receivedPacketHistory", func() {
|
||||
var err error
|
||||
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
|
||||
err = handler.ReceivedPacket(2*i+1, time.Time{}, true)
|
||||
// this will eventually return an error
|
||||
// details about when exactly the receivedPacketHistory errors are tested there
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ACKs", func() {
|
||||
Context("queueing ACKs", func() {
|
||||
receiveAndAck10Packets := func() {
|
||||
for i := 1; i <= 10; i++ {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
}
|
||||
|
||||
receiveAndAckPacketsUntilAckDecimation := func() {
|
||||
for i := 1; i <= minReceivedBeforeAckDecimation; i++ {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
}
|
||||
|
||||
It("always queues an ACK for the first packet", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
})
|
||||
|
||||
It("works with packet number 0", func() {
|
||||
err := handler.ReceivedPacket(0, time.Time{}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
})
|
||||
|
||||
It("queues an ACK for every second retransmittable packet at the beginning", func() {
|
||||
receiveAndAck10Packets()
|
||||
p := protocol.PacketNumber(11)
|
||||
for i := 0; i <= 20; i++ {
|
||||
err := handler.ReceivedPacket(p, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
p++
|
||||
err = handler.ReceivedPacket(p, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
p++
|
||||
// dequeue the ACK frame
|
||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
||||
}
|
||||
})
|
||||
|
||||
It("queues an ACK for every 10 retransmittable packet, if they are arriving fast", func() {
|
||||
receiveAndAck10Packets()
|
||||
p := protocol.PacketNumber(10000)
|
||||
for i := 0; i < 9; i++ {
|
||||
err := handler.ReceivedPacket(p, time.Now(), true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
p++
|
||||
}
|
||||
Expect(handler.GetAlarmTimeout()).NotTo(BeZero())
|
||||
err := handler.ReceivedPacket(p, time.Now(), true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
})
|
||||
|
||||
It("only sets the timer when receiving a retransmittable packets", func() {
|
||||
receiveAndAck10Packets()
|
||||
err := handler.ReceivedPacket(11, time.Now(), false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
rcvTime := time.Now().Add(10 * time.Millisecond)
|
||||
err = handler.ReceivedPacket(12, rcvTime, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
Expect(handler.GetAlarmTimeout()).To(Equal(rcvTime.Add(ackSendDelay)))
|
||||
})
|
||||
|
||||
It("queues an ACK if it was reported missing before", func() {
|
||||
receiveAndAck10Packets()
|
||||
err := handler.ReceivedPacket(11, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(13, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame() // ACK: 1-11 and 13, missing: 12
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.HasMissingRanges()).To(BeTrue())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
err = handler.ReceivedPacket(12, time.Time{}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't queue an ACK if it was reported missing before, but is below the threshold", func() {
|
||||
receiveAndAck10Packets()
|
||||
// 11 is missing
|
||||
err := handler.ReceivedPacket(12, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(13, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame() // ACK: 1-10, 12-13
|
||||
Expect(ack).ToNot(BeNil())
|
||||
// now receive 11
|
||||
handler.IgnoreBelow(12)
|
||||
err = handler.ReceivedPacket(11, time.Time{}, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack = handler.GetAckFrame()
|
||||
Expect(ack).To(BeNil())
|
||||
})
|
||||
|
||||
It("doesn't queue an ACK if the packet closes a gap that was not yet reported", func() {
|
||||
receiveAndAckPacketsUntilAckDecimation()
|
||||
p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1)
|
||||
err := handler.ReceivedPacket(p+1, time.Now(), true) // p is missing now
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
Expect(handler.GetAlarmTimeout()).ToNot(BeZero())
|
||||
err = handler.ReceivedPacket(p, time.Now(), true) // p is not missing any more
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
})
|
||||
|
||||
It("sets an ACK alarm after 1/4 RTT if it creates a new missing range", func() {
|
||||
now := time.Now().Add(-time.Hour)
|
||||
rtt := 80 * time.Millisecond
|
||||
rttStats.UpdateRTT(rtt, 0, now)
|
||||
receiveAndAckPacketsUntilAckDecimation()
|
||||
p := protocol.PacketNumber(minReceivedBeforeAckDecimation + 1)
|
||||
for i := p; i < p+6; i++ {
|
||||
err := handler.ReceivedPacket(i, now, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := handler.ReceivedPacket(p+10, now, true) // we now know that packets p+7, p+8 and p+9
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rttStats.MinRTT()).To(Equal(rtt))
|
||||
Expect(handler.ackAlarm.Sub(now)).To(Equal(rtt / 8))
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack.HasMissingRanges()).To(BeTrue())
|
||||
Expect(ack).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("ACK generation", func() {
|
||||
BeforeEach(func() {
|
||||
handler.ackQueued = true
|
||||
})
|
||||
|
||||
It("generates a simple ACK frame", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(2, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(2)))
|
||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(ack.HasMissingRanges()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("generates an ACK for packet number 0", func() {
|
||||
err := handler.ReceivedPacket(0, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(0)))
|
||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0)))
|
||||
Expect(ack.HasMissingRanges()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("sets the delay time", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(2, time.Now().Add(-1337*time.Millisecond), true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.DelayTime).To(BeNumerically("~", 1337*time.Millisecond, 50*time.Millisecond))
|
||||
})
|
||||
|
||||
It("saves the last sent ACK", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(handler.lastAck).To(Equal(ack))
|
||||
err = handler.ReceivedPacket(2, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackQueued = true
|
||||
ack = handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(handler.lastAck).To(Equal(ack))
|
||||
})
|
||||
|
||||
It("generates an ACK frame with missing packets", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(4, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(4)))
|
||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(ack.AckRanges).To(Equal([]wire.AckRange{
|
||||
{Smallest: 4, Largest: 4},
|
||||
{Smallest: 1, Largest: 1},
|
||||
}))
|
||||
})
|
||||
|
||||
It("generates an ACK for packet number 0 and other packets", func() {
|
||||
err := handler.ReceivedPacket(0, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(3, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3)))
|
||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(0)))
|
||||
Expect(ack.AckRanges).To(Equal([]wire.AckRange{
|
||||
{Smallest: 3, Largest: 3},
|
||||
{Smallest: 0, Largest: 1},
|
||||
}))
|
||||
})
|
||||
|
||||
It("accepts packets below the lower limit", func() {
|
||||
handler.IgnoreBelow(6)
|
||||
err := handler.ReceivedPacket(2, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("doesn't add delayed packets to the packetHistory", func() {
|
||||
handler.IgnoreBelow(7)
|
||||
err := handler.ReceivedPacket(4, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.ReceivedPacket(10, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10)))
|
||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(10)))
|
||||
})
|
||||
|
||||
It("deletes packets from the packetHistory when a lower limit is set", func() {
|
||||
for i := 1; i <= 12; i++ {
|
||||
err := handler.ReceivedPacket(protocol.PacketNumber(i), time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
handler.IgnoreBelow(7)
|
||||
// check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(12)))
|
||||
Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(7)))
|
||||
Expect(ack.HasMissingRanges()).To(BeFalse())
|
||||
})
|
||||
|
||||
// TODO: remove this test when dropping support for STOP_WAITINGs
|
||||
It("handles a lower limit of 0", func() {
|
||||
handler.IgnoreBelow(0)
|
||||
err := handler.ReceivedPacket(1337, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
ack := handler.GetAckFrame()
|
||||
Expect(ack).ToNot(BeNil())
|
||||
Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(1337)))
|
||||
})
|
||||
|
||||
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackAlarm = time.Now().Add(-time.Minute)
|
||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
||||
Expect(handler.packetsReceivedSinceLastAck).To(BeZero())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
Expect(handler.retransmittablePacketsReceivedSinceLastAck).To(BeZero())
|
||||
Expect(handler.ackQueued).To(BeFalse())
|
||||
})
|
||||
|
||||
It("doesn't generate an ACK when none is queued and the timer is not set", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackQueued = false
|
||||
handler.ackAlarm = time.Time{}
|
||||
Expect(handler.GetAckFrame()).To(BeNil())
|
||||
})
|
||||
|
||||
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackQueued = false
|
||||
handler.ackAlarm = time.Now().Add(time.Minute)
|
||||
Expect(handler.GetAckFrame()).To(BeNil())
|
||||
})
|
||||
|
||||
It("generates an ACK when the timer has expired", func() {
|
||||
err := handler.ReceivedPacket(1, time.Time{}, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler.ackQueued = false
|
||||
handler.ackAlarm = time.Now().Add(-time.Minute)
|
||||
Expect(handler.GetAckFrame()).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
121
vendor/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go
vendored
Normal file
121
vendor/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go
vendored
Normal file
|
@ -0,0 +1,121 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// The receivedPacketHistory stores if a packet number has already been received.
|
||||
// It does not store packet contents.
|
||||
type receivedPacketHistory struct {
|
||||
ranges *utils.PacketIntervalList
|
||||
|
||||
lowestInReceivedPacketNumbers protocol.PacketNumber
|
||||
}
|
||||
|
||||
var errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
|
||||
|
||||
// newReceivedPacketHistory creates a new received packet history
|
||||
func newReceivedPacketHistory() *receivedPacketHistory {
|
||||
return &receivedPacketHistory{
|
||||
ranges: utils.NewPacketIntervalList(),
|
||||
}
|
||||
}
|
||||
|
||||
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
|
||||
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
|
||||
if h.ranges.Len() >= protocol.MaxTrackedReceivedAckRanges {
|
||||
return errTooManyOutstandingReceivedAckRanges
|
||||
}
|
||||
|
||||
if h.ranges.Len() == 0 {
|
||||
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
|
||||
return nil
|
||||
}
|
||||
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
// p already included in an existing range. Nothing to do here
|
||||
if p >= el.Value.Start && p <= el.Value.End {
|
||||
return nil
|
||||
}
|
||||
|
||||
var rangeExtended bool
|
||||
if el.Value.End == p-1 { // extend a range at the end
|
||||
rangeExtended = true
|
||||
el.Value.End = p
|
||||
} else if el.Value.Start == p+1 { // extend a range at the beginning
|
||||
rangeExtended = true
|
||||
el.Value.Start = p
|
||||
}
|
||||
|
||||
// if a range was extended (either at the beginning or at the end, maybe it is possible to merge two ranges into one)
|
||||
if rangeExtended {
|
||||
prev := el.Prev()
|
||||
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
|
||||
prev.Value.End = el.Value.End
|
||||
h.ranges.Remove(el)
|
||||
return nil
|
||||
}
|
||||
return nil // if the two ranges were not merge, we're done here
|
||||
}
|
||||
|
||||
// create a new range at the end
|
||||
if p > el.Value.End {
|
||||
h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// create a new range at the beginning
|
||||
h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteBelow deletes all entries below (but not including) p
|
||||
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
|
||||
if p <= h.lowestInReceivedPacketNumbers {
|
||||
return
|
||||
}
|
||||
h.lowestInReceivedPacketNumbers = p
|
||||
|
||||
nextEl := h.ranges.Front()
|
||||
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
||||
nextEl = el.Next()
|
||||
|
||||
if p > el.Value.Start && p <= el.Value.End {
|
||||
el.Value.Start = p
|
||||
} else if el.Value.End < p { // delete a whole range
|
||||
h.ranges.Remove(el)
|
||||
} else { // no ranges affected. Nothing to do
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
|
||||
func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
|
||||
if h.ranges.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ackRanges := make([]wire.AckRange, h.ranges.Len())
|
||||
i := 0
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
|
||||
i++
|
||||
}
|
||||
return ackRanges
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
|
||||
ackRange := wire.AckRange{}
|
||||
if h.ranges.Len() > 0 {
|
||||
r := h.ranges.Back().Value
|
||||
ackRange.Smallest = r.Start
|
||||
ackRange.Largest = r.End
|
||||
}
|
||||
return ackRange
|
||||
}
|
248
vendor/lucas-clemente/quic-go/internal/ackhandler/received_packet_history_test.go
vendored
Normal file
248
vendor/lucas-clemente/quic-go/internal/ackhandler/received_packet_history_test.go
vendored
Normal file
|
@ -0,0 +1,248 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("receivedPacketHistory", func() {
|
||||
var (
|
||||
hist *receivedPacketHistory
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
hist = newReceivedPacketHistory()
|
||||
})
|
||||
|
||||
Context("ranges", func() {
|
||||
It("adds the first packet", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
||||
})
|
||||
|
||||
It("doesn't care about duplicate packets", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
||||
})
|
||||
|
||||
It("adds a few consecutive packets", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(5)
|
||||
hist.ReceivedPacket(6)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
|
||||
})
|
||||
|
||||
It("doesn't care about a duplicate packet contained in an existing range", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(5)
|
||||
hist.ReceivedPacket(6)
|
||||
hist.ReceivedPacket(5)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
|
||||
})
|
||||
|
||||
It("extends a range at the front", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(3)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 3, End: 4}))
|
||||
})
|
||||
|
||||
It("creates a new range when a packet is lost", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(6)
|
||||
Expect(hist.ranges.Len()).To(Equal(2))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6}))
|
||||
})
|
||||
|
||||
It("creates a new range in between two ranges", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(10)
|
||||
Expect(hist.ranges.Len()).To(Equal(2))
|
||||
hist.ReceivedPacket(7)
|
||||
Expect(hist.ranges.Len()).To(Equal(3))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
||||
Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7}))
|
||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
||||
})
|
||||
|
||||
It("creates a new range before an existing range for a belated packet", func() {
|
||||
hist.ReceivedPacket(6)
|
||||
hist.ReceivedPacket(4)
|
||||
Expect(hist.ranges.Len()).To(Equal(2))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 6}))
|
||||
})
|
||||
|
||||
It("extends a previous range at the end", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(7)
|
||||
hist.ReceivedPacket(5)
|
||||
Expect(hist.ranges.Len()).To(Equal(2))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 5}))
|
||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 7, End: 7}))
|
||||
})
|
||||
|
||||
It("extends a range at the front", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(7)
|
||||
hist.ReceivedPacket(6)
|
||||
Expect(hist.ranges.Len()).To(Equal(2))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 6, End: 7}))
|
||||
})
|
||||
|
||||
It("closes a range", func() {
|
||||
hist.ReceivedPacket(6)
|
||||
hist.ReceivedPacket(4)
|
||||
Expect(hist.ranges.Len()).To(Equal(2))
|
||||
hist.ReceivedPacket(5)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
|
||||
})
|
||||
|
||||
It("closes a range in the middle", func() {
|
||||
hist.ReceivedPacket(1)
|
||||
hist.ReceivedPacket(10)
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(6)
|
||||
Expect(hist.ranges.Len()).To(Equal(4))
|
||||
hist.ReceivedPacket(5)
|
||||
Expect(hist.ranges.Len()).To(Equal(3))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 1, End: 1}))
|
||||
Expect(hist.ranges.Front().Next().Value).To(Equal(utils.PacketInterval{Start: 4, End: 6}))
|
||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("deleting", func() {
|
||||
It("does nothing when the history is empty", func() {
|
||||
hist.DeleteBelow(5)
|
||||
Expect(hist.ranges.Len()).To(BeZero())
|
||||
})
|
||||
|
||||
It("deletes a range", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(5)
|
||||
hist.ReceivedPacket(10)
|
||||
hist.DeleteBelow(6)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
||||
})
|
||||
|
||||
It("deletes multiple ranges", func() {
|
||||
hist.ReceivedPacket(1)
|
||||
hist.ReceivedPacket(5)
|
||||
hist.ReceivedPacket(10)
|
||||
hist.DeleteBelow(8)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
||||
})
|
||||
|
||||
It("adjusts a range, if packets are delete from an existing range", func() {
|
||||
hist.ReceivedPacket(3)
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(5)
|
||||
hist.ReceivedPacket(6)
|
||||
hist.ReceivedPacket(7)
|
||||
hist.DeleteBelow(5)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 7}))
|
||||
})
|
||||
|
||||
It("adjusts a range, if only one packet remains in the range", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(5)
|
||||
hist.ReceivedPacket(10)
|
||||
hist.DeleteBelow(5)
|
||||
Expect(hist.ranges.Len()).To(Equal(2))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 5, End: 5}))
|
||||
Expect(hist.ranges.Back().Value).To(Equal(utils.PacketInterval{Start: 10, End: 10}))
|
||||
})
|
||||
|
||||
It("keeps a one-packet range, if deleting up to the packet directly below", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.DeleteBelow(4)
|
||||
Expect(hist.ranges.Len()).To(Equal(1))
|
||||
Expect(hist.ranges.Front().Value).To(Equal(utils.PacketInterval{Start: 4, End: 4}))
|
||||
})
|
||||
|
||||
Context("DoS protection", func() {
|
||||
It("doesn't create more than MaxTrackedReceivedAckRanges ranges", func() {
|
||||
for i := protocol.PacketNumber(1); i <= protocol.MaxTrackedReceivedAckRanges; i++ {
|
||||
err := hist.ReceivedPacket(2 * i)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 2)
|
||||
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
|
||||
})
|
||||
|
||||
It("doesn't consider already deleted ranges for MaxTrackedReceivedAckRanges", func() {
|
||||
for i := protocol.PacketNumber(1); i <= protocol.MaxTrackedReceivedAckRanges; i++ {
|
||||
err := hist.ReceivedPacket(2 * i)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 2)
|
||||
Expect(err).To(MatchError(errTooManyOutstandingReceivedAckRanges))
|
||||
hist.DeleteBelow(protocol.MaxTrackedReceivedAckRanges) // deletes about half of the ranges
|
||||
err = hist.ReceivedPacket(2*protocol.MaxTrackedReceivedAckRanges + 4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("ACK range export", func() {
|
||||
It("returns nil if there are no ranges", func() {
|
||||
Expect(hist.GetAckRanges()).To(BeNil())
|
||||
})
|
||||
|
||||
It("gets a single ACK range", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(5)
|
||||
ackRanges := hist.GetAckRanges()
|
||||
Expect(ackRanges).To(HaveLen(1))
|
||||
Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
|
||||
})
|
||||
|
||||
It("gets multiple ACK ranges", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(5)
|
||||
hist.ReceivedPacket(6)
|
||||
hist.ReceivedPacket(1)
|
||||
hist.ReceivedPacket(11)
|
||||
hist.ReceivedPacket(10)
|
||||
hist.ReceivedPacket(2)
|
||||
ackRanges := hist.GetAckRanges()
|
||||
Expect(ackRanges).To(HaveLen(3))
|
||||
Expect(ackRanges[0]).To(Equal(wire.AckRange{Smallest: 10, Largest: 11}))
|
||||
Expect(ackRanges[1]).To(Equal(wire.AckRange{Smallest: 4, Largest: 6}))
|
||||
Expect(ackRanges[2]).To(Equal(wire.AckRange{Smallest: 1, Largest: 2}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Getting the highest ACK range", func() {
|
||||
It("returns the zero value if there are no ranges", func() {
|
||||
Expect(hist.GetHighestAckRange()).To(BeZero())
|
||||
})
|
||||
|
||||
It("gets a single ACK range", func() {
|
||||
hist.ReceivedPacket(4)
|
||||
hist.ReceivedPacket(5)
|
||||
Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 4, Largest: 5}))
|
||||
})
|
||||
|
||||
It("gets the highest of multiple ACK ranges", func() {
|
||||
hist.ReceivedPacket(3)
|
||||
hist.ReceivedPacket(6)
|
||||
hist.ReceivedPacket(7)
|
||||
Expect(hist.GetHighestAckRange()).To(Equal(wire.AckRange{Smallest: 6, Largest: 7}))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,36 @@
|
|||
package ackhandler
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
||||
// Returns a new slice with all non-retransmittable frames deleted.
|
||||
func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
|
||||
res := make([]wire.Frame, 0, len(fs))
|
||||
for _, f := range fs {
|
||||
if IsFrameRetransmittable(f) {
|
||||
res = append(res, f)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// IsFrameRetransmittable returns true if the frame should be retransmitted.
|
||||
func IsFrameRetransmittable(f wire.Frame) bool {
|
||||
switch f.(type) {
|
||||
case *wire.StopWaitingFrame:
|
||||
return false
|
||||
case *wire.AckFrame:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
|
||||
func HasRetransmittableFrames(fs []wire.Frame) bool {
|
||||
for _, f := range fs {
|
||||
if IsFrameRetransmittable(f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("retransmittable frames", func() {
|
||||
for fl, el := range map[wire.Frame]bool{
|
||||
&wire.AckFrame{}: false,
|
||||
&wire.StopWaitingFrame{}: false,
|
||||
&wire.BlockedFrame{}: true,
|
||||
&wire.ConnectionCloseFrame{}: true,
|
||||
&wire.GoawayFrame{}: true,
|
||||
&wire.PingFrame{}: true,
|
||||
&wire.RstStreamFrame{}: true,
|
||||
&wire.StreamFrame{}: true,
|
||||
&wire.MaxDataFrame{}: true,
|
||||
&wire.MaxStreamDataFrame{}: true,
|
||||
} {
|
||||
f := fl
|
||||
e := el
|
||||
fName := reflect.ValueOf(f).Elem().Type().Name()
|
||||
|
||||
It("works for "+fName, func() {
|
||||
Expect(IsFrameRetransmittable(f)).To(Equal(e))
|
||||
})
|
||||
|
||||
It("stripping non-retransmittable frames works for "+fName, func() {
|
||||
s := []wire.Frame{f}
|
||||
if e {
|
||||
Expect(stripNonRetransmittableFrames(s)).To(Equal([]wire.Frame{f}))
|
||||
} else {
|
||||
Expect(stripNonRetransmittableFrames(s)).To(BeEmpty())
|
||||
}
|
||||
})
|
||||
|
||||
It("HasRetransmittableFrames works for "+fName, func() {
|
||||
Expect(HasRetransmittableFrames([]wire.Frame{f})).To(Equal(e))
|
||||
})
|
||||
}
|
||||
})
|
|
@ -0,0 +1,40 @@
|
|||
package ackhandler
|
||||
|
||||
import "fmt"
|
||||
|
||||
// The SendMode says what kind of packets can be sent.
|
||||
type SendMode uint8
|
||||
|
||||
const (
|
||||
// SendNone means that no packets should be sent
|
||||
SendNone SendMode = iota
|
||||
// SendAck means an ACK-only packet should be sent
|
||||
SendAck
|
||||
// SendRetransmission means that retransmissions should be sent
|
||||
SendRetransmission
|
||||
// SendRTO means that an RTO probe packet should be sent
|
||||
SendRTO
|
||||
// SendTLP means that a TLP probe packet should be sent
|
||||
SendTLP
|
||||
// SendAny means that any packet should be sent
|
||||
SendAny
|
||||
)
|
||||
|
||||
func (s SendMode) String() string {
|
||||
switch s {
|
||||
case SendNone:
|
||||
return "none"
|
||||
case SendAck:
|
||||
return "ack"
|
||||
case SendRetransmission:
|
||||
return "retransmission"
|
||||
case SendRTO:
|
||||
return "rto"
|
||||
case SendTLP:
|
||||
return "tlp"
|
||||
case SendAny:
|
||||
return "any"
|
||||
default:
|
||||
return fmt.Sprintf("invalid send mode: %d", s)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Send Mode", func() {
|
||||
It("has a string representation", func() {
|
||||
Expect(SendNone.String()).To(Equal("none"))
|
||||
Expect(SendAny.String()).To(Equal("any"))
|
||||
Expect(SendAck.String()).To(Equal("ack"))
|
||||
Expect(SendRTO.String()).To(Equal("rto"))
|
||||
Expect(SendTLP.String()).To(Equal("tlp"))
|
||||
Expect(SendRetransmission.String()).To(Equal("retransmission"))
|
||||
Expect(SendMode(123).String()).To(Equal("invalid send mode: 123"))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,658 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||
// In fraction of an RTT.
|
||||
timeReorderingFraction = 1.0 / 8
|
||||
// defaultRTOTimeout is the RTO time on new connections
|
||||
defaultRTOTimeout = 500 * time.Millisecond
|
||||
// Minimum time in the future a tail loss probe alarm may be set for.
|
||||
minTPLTimeout = 10 * time.Millisecond
|
||||
// Maximum number of tail loss probes before an RTO fires.
|
||||
maxTLPs = 2
|
||||
// Minimum time in the future an RTO alarm may be set for.
|
||||
minRTOTimeout = 200 * time.Millisecond
|
||||
// maxRTOTimeout is the maximum RTO time
|
||||
maxRTOTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
type sentPacketHandler struct {
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
lastSentRetransmittablePacketTime time.Time
|
||||
lastSentHandshakePacketTime time.Time
|
||||
|
||||
nextPacketSendTime time.Time
|
||||
skippedPackets []protocol.PacketNumber
|
||||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestReceivedPacketWithAck protocol.PacketNumber
|
||||
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
|
||||
// example: we send an ACK for packets 90-100 with packet number 20
|
||||
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
|
||||
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||
largestSentBeforeRTO protocol.PacketNumber
|
||||
|
||||
packetHistory *sentPacketHistory
|
||||
stopWaitingManager stopWaitingManager
|
||||
|
||||
retransmissionQueue []*Packet
|
||||
|
||||
bytesInFlight protocol.ByteCount
|
||||
|
||||
congestion congestion.SendAlgorithm
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
handshakeComplete bool
|
||||
// The number of times the handshake packets have been retransmitted without receiving an ack.
|
||||
handshakeCount uint32
|
||||
|
||||
// The number of times a TLP has been sent without receiving an ack.
|
||||
tlpCount uint32
|
||||
allowTLP bool
|
||||
|
||||
// The number of times an RTO has been sent without receiving an ack.
|
||||
rtoCount uint32
|
||||
// The number of RTO probe packets that should be sent.
|
||||
numRTOs int
|
||||
|
||||
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
|
||||
lossTime time.Time
|
||||
|
||||
// The alarm timeout
|
||||
alarm time.Time
|
||||
|
||||
logger utils.Logger
|
||||
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
// NewSentPacketHandler creates a new sentPacketHandler
|
||||
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
false, /* don't use reno since chromium doesn't (why?) */
|
||||
protocol.InitialCongestionWindow,
|
||||
protocol.DefaultMaxCongestionWindow,
|
||||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
packetHistory: newSentPacketHistory(),
|
||||
stopWaitingManager: stopWaitingManager{},
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
|
||||
if p := h.packetHistory.FirstOutstanding(); p != nil {
|
||||
return p.PacketNumber
|
||||
}
|
||||
return h.largestAcked + 1
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeComplete() {
|
||||
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
|
||||
var queue []*Packet
|
||||
for _, packet := range h.retransmissionQueue {
|
||||
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
|
||||
queue = append(queue, packet)
|
||||
}
|
||||
}
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
for _, p := range handshakePackets {
|
||||
h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
h.retransmissionQueue = queue
|
||||
h.handshakeComplete = true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacket(packet *Packet) {
|
||||
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
|
||||
h.packetHistory.SentPacket(packet)
|
||||
h.updateLossDetectionAlarm()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
|
||||
var p []*Packet
|
||||
for _, packet := range packets {
|
||||
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
|
||||
p = append(p, packet)
|
||||
}
|
||||
}
|
||||
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
|
||||
h.updateLossDetectionAlarm()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
|
||||
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
|
||||
h.skippedPackets = append(h.skippedPackets, p)
|
||||
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
|
||||
h.skippedPackets = h.skippedPackets[1:]
|
||||
}
|
||||
}
|
||||
|
||||
h.lastSentPacketNumber = packet.PacketNumber
|
||||
|
||||
if len(packet.Frames) > 0 {
|
||||
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
|
||||
packet.largestAcked = ackFrame.LargestAcked()
|
||||
}
|
||||
}
|
||||
|
||||
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
|
||||
isRetransmittable := len(packet.Frames) != 0
|
||||
|
||||
if isRetransmittable {
|
||||
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.lastSentHandshakePacketTime = packet.SendTime
|
||||
}
|
||||
h.lastSentRetransmittablePacketTime = packet.SendTime
|
||||
packet.includedInBytesInFlight = true
|
||||
h.bytesInFlight += packet.Length
|
||||
packet.canBeRetransmitted = true
|
||||
if h.numRTOs > 0 {
|
||||
h.numRTOs--
|
||||
}
|
||||
h.allowTLP = false
|
||||
}
|
||||
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
|
||||
|
||||
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
|
||||
return isRetransmittable
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
|
||||
largestAcked := ackFrame.LargestAcked()
|
||||
if largestAcked > h.lastSentPacketNumber {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
|
||||
}
|
||||
|
||||
// duplicate or out of order ACK
|
||||
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
|
||||
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
|
||||
return nil
|
||||
}
|
||||
h.largestReceivedPacketWithAck = withPacketNumber
|
||||
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
|
||||
|
||||
if h.skippedPacketsAcked(ackFrame) {
|
||||
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
|
||||
}
|
||||
|
||||
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
|
||||
h.congestion.MaybeExitSlowStart()
|
||||
}
|
||||
|
||||
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
priorInFlight := h.bytesInFlight
|
||||
for _, p := range ackedPackets {
|
||||
if encLevel < p.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||
}
|
||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||
if p.largestAcked != 0 {
|
||||
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
|
||||
}
|
||||
if err := h.onPacketAcked(p, rcvTime); err != nil {
|
||||
return err
|
||||
}
|
||||
if p.includedInBytesInFlight {
|
||||
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
|
||||
return err
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
|
||||
return h.lowestPacketNotConfirmedAcked
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
|
||||
var ackedPackets []*Packet
|
||||
ackRangeIndex := 0
|
||||
lowestAcked := ackFrame.LowestAcked()
|
||||
largestAcked := ackFrame.LargestAcked()
|
||||
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
// Ignore packets below the lowest acked
|
||||
if p.PacketNumber < lowestAcked {
|
||||
return true, nil
|
||||
}
|
||||
// Break after largest acked is reached
|
||||
if p.PacketNumber > largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ackFrame.HasMissingRanges() {
|
||||
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
|
||||
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
|
||||
ackRangeIndex++
|
||||
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
|
||||
}
|
||||
|
||||
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
|
||||
if p.PacketNumber > ackRange.Largest {
|
||||
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
|
||||
}
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
}
|
||||
} else {
|
||||
ackedPackets = append(ackedPackets, p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
if h.logger.Debug() && len(ackedPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(ackedPackets))
|
||||
for i, p := range ackedPackets {
|
||||
pns[i] = p.PacketNumber
|
||||
}
|
||||
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
|
||||
}
|
||||
return ackedPackets, err
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
|
||||
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) updateLossDetectionAlarm() {
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if !h.packetHistory.HasOutstandingPackets() {
|
||||
h.alarm = time.Time{}
|
||||
return
|
||||
}
|
||||
|
||||
if h.packetHistory.HasOutstandingHandshakePackets() {
|
||||
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
|
||||
} else if !h.lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = h.lossTime
|
||||
} else {
|
||||
// RTO or TLP alarm
|
||||
alarmDuration := h.computeRTOTimeout()
|
||||
if h.tlpCount < maxTLPs {
|
||||
tlpAlarm := h.computeTLPTimeout()
|
||||
// if the RTO duration is shorter than the TLP duration, use the RTO duration
|
||||
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
|
||||
}
|
||||
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
|
||||
h.lossTime = time.Time{}
|
||||
|
||||
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
|
||||
|
||||
var lostPackets []*Packet
|
||||
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
|
||||
if packet.PacketNumber > h.largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
timeSinceSent := now.Sub(packet.SendTime)
|
||||
if timeSinceSent > delayUntilLost {
|
||||
lostPackets = append(lostPackets, packet)
|
||||
} else if h.lossTime.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
|
||||
}
|
||||
// Note: This conditional is only entered once per call
|
||||
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
if h.logger.Debug() && len(lostPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(lostPackets))
|
||||
for i, p := range lostPackets {
|
||||
pns[i] = p.PacketNumber
|
||||
}
|
||||
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
|
||||
}
|
||||
|
||||
for _, p := range lostPackets {
|
||||
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
|
||||
if p.includedInBytesInFlight {
|
||||
h.bytesInFlight -= p.Length
|
||||
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
|
||||
}
|
||||
if p.canBeRetransmitted {
|
||||
// queue the packet for retransmission, and report the loss to the congestion controller
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) OnAlarm() error {
|
||||
// When all outstanding are acknowledged, the alarm is canceled in
|
||||
// updateLossDetectionAlarm. This doesn't reset the timer in the session though.
|
||||
// When OnAlarm is called, we therefore need to make sure that there are
|
||||
// actually packets outstanding.
|
||||
if h.packetHistory.HasOutstandingPackets() {
|
||||
if err := h.onVerifiedAlarm(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.updateLossDetectionAlarm()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) onVerifiedAlarm() error {
|
||||
var err error
|
||||
if h.packetHistory.HasOutstandingHandshakePackets() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
|
||||
}
|
||||
h.handshakeCount++
|
||||
err = h.queueHandshakePacketsForRetransmission()
|
||||
} else if !h.lossTime.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
|
||||
}
|
||||
// Early retransmit or time loss detection
|
||||
err = h.detectLostPackets(time.Now(), h.bytesInFlight)
|
||||
} else if h.tlpCount < maxTLPs { // TLP
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in TLP mode. TLP count: %d", h.tlpCount)
|
||||
}
|
||||
h.allowTLP = true
|
||||
h.tlpCount++
|
||||
} else { // RTO
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in RTO mode. RTO count: %d", h.rtoCount)
|
||||
}
|
||||
if h.rtoCount == 0 {
|
||||
h.largestSentBeforeRTO = h.lastSentPacketNumber
|
||||
}
|
||||
h.rtoCount++
|
||||
h.numRTOs += 2
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
|
||||
return h.alarm
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
|
||||
// This happens if a packet and its retransmissions is acked in the same ACK.
|
||||
// As soon as we process the first one, this will remove all the retransmissions,
|
||||
// so we won't find the retransmitted packet number later.
|
||||
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// only report the acking of this packet to the congestion controller if:
|
||||
// * it is a retransmittable packet
|
||||
// * this packet wasn't retransmitted yet
|
||||
if p.isRetransmission {
|
||||
// that the parent doesn't exist is expected to happen every time the original packet was already acked
|
||||
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
|
||||
if len(parent.retransmittedAs) == 1 {
|
||||
parent.retransmittedAs = nil
|
||||
} else {
|
||||
// remove this packet from the slice of retransmission
|
||||
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
|
||||
for _, pn := range parent.retransmittedAs {
|
||||
if pn != p.PacketNumber {
|
||||
retransmittedAs = append(retransmittedAs, pn)
|
||||
}
|
||||
}
|
||||
parent.retransmittedAs = retransmittedAs
|
||||
}
|
||||
}
|
||||
}
|
||||
// this also applies to packets that have been retransmitted as probe packets
|
||||
if p.includedInBytesInFlight {
|
||||
h.bytesInFlight -= p.Length
|
||||
}
|
||||
if h.rtoCount > 0 {
|
||||
h.verifyRTO(p.PacketNumber)
|
||||
}
|
||||
if err := h.stopRetransmissionsFor(p); err != nil {
|
||||
return err
|
||||
}
|
||||
h.rtoCount = 0
|
||||
h.tlpCount = 0
|
||||
h.handshakeCount = 0
|
||||
return h.packetHistory.Remove(p.PacketNumber)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
|
||||
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, r := range p.retransmittedAs {
|
||||
packet := h.packetHistory.GetPacket(r)
|
||||
if packet == nil {
|
||||
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
|
||||
}
|
||||
h.stopRetransmissionsFor(packet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
|
||||
if pn <= h.largestSentBeforeRTO {
|
||||
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
|
||||
// Replace SRTT with latest_rtt and increase the variance to prevent
|
||||
// a spurious RTO from happening again.
|
||||
h.rttStats.ExpireSmoothedMetrics()
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
|
||||
h.congestion.OnRetransmissionTimeout(true)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
|
||||
if len(h.retransmissionQueue) == 0 {
|
||||
return nil
|
||||
}
|
||||
packet := h.retransmissionQueue[0]
|
||||
// Shift the slice and don't retain anything that isn't needed.
|
||||
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
|
||||
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
|
||||
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
|
||||
return packet
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
|
||||
if len(h.retransmissionQueue) == 0 {
|
||||
p := h.packetHistory.FirstOutstanding()
|
||||
if p == nil {
|
||||
return nil, errors.New("cannot dequeue a probe packet. No outstanding packets")
|
||||
}
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return h.DequeuePacketForRetransmission(), nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
|
||||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendMode() SendMode {
|
||||
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
|
||||
|
||||
// Don't send any packets if we're keeping track of the maximum number of packets.
|
||||
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
|
||||
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
|
||||
// but still allow sending of retransmissions and ACKs.
|
||||
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
|
||||
}
|
||||
return SendNone
|
||||
}
|
||||
if h.allowTLP {
|
||||
return SendTLP
|
||||
}
|
||||
if h.numRTOs > 0 {
|
||||
return SendRTO
|
||||
}
|
||||
// Only send ACKs if we're congestion limited.
|
||||
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
// Send retransmissions first, if there are any.
|
||||
if len(h.retransmissionQueue) > 0 {
|
||||
return SendRetransmission
|
||||
}
|
||||
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
return SendAny
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
||||
return h.nextPacketSendTime
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
||||
if h.numRTOs > 0 {
|
||||
// RTO probes should not be paced, but must be sent immediately.
|
||||
return h.numRTOs
|
||||
}
|
||||
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
|
||||
if delay == 0 || delay > protocol.MinPacingDelay {
|
||||
return 1
|
||||
}
|
||||
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
for _, p := range handshakePackets {
|
||||
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
|
||||
if err := h.queuePacketForRetransmission(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
|
||||
if !p.canBeRetransmitted {
|
||||
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
|
||||
}
|
||||
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
|
||||
return err
|
||||
}
|
||||
h.retransmissionQueue = append(h.retransmissionQueue, p)
|
||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
|
||||
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
|
||||
// exponential backoff
|
||||
// There's an implicit limit to this set by the handshake timeout.
|
||||
return duration << h.handshakeCount
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
|
||||
// TODO(#1236): include the max_ack_delay
|
||||
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
|
||||
var rto time.Duration
|
||||
rtt := h.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
rto = defaultRTOTimeout
|
||||
} else {
|
||||
rto = rtt + 4*h.rttStats.MeanDeviation()
|
||||
}
|
||||
rto = utils.MaxDuration(rto, minRTOTimeout)
|
||||
// Exponential backoff
|
||||
rto = rto << h.rtoCount
|
||||
return utils.MinDuration(rto, maxRTOTimeout)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
|
||||
for _, p := range h.skippedPackets {
|
||||
if ackFrame.AcksPacket(p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
|
||||
lowestUnacked := h.lowestUnacked()
|
||||
deleteIndex := 0
|
||||
for i, p := range h.skippedPackets {
|
||||
if p < lowestUnacked {
|
||||
deleteIndex = i + 1
|
||||
}
|
||||
}
|
||||
h.skippedPackets = h.skippedPackets[deleteIndex:]
|
||||
}
|
1080
vendor/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler_test.go
vendored
Normal file
1080
vendor/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler_test.go
vendored
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,168 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type sentPacketHistory struct {
|
||||
packetList *PacketList
|
||||
packetMap map[protocol.PacketNumber]*PacketElement
|
||||
|
||||
numOutstandingPackets int
|
||||
numOutstandingHandshakePackets int
|
||||
|
||||
firstOutstanding *PacketElement
|
||||
}
|
||||
|
||||
func newSentPacketHistory() *sentPacketHistory {
|
||||
return &sentPacketHistory{
|
||||
packetList: NewPacketList(),
|
||||
packetMap: make(map[protocol.PacketNumber]*PacketElement),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentPacket(p *Packet) {
|
||||
h.sentPacketImpl(p)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
|
||||
el := h.packetList.PushBack(*p)
|
||||
h.packetMap[p.PacketNumber] = el
|
||||
if h.firstOutstanding == nil {
|
||||
h.firstOutstanding = el
|
||||
}
|
||||
if p.canBeRetransmitted {
|
||||
h.numOutstandingPackets++
|
||||
if p.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.numOutstandingHandshakePackets++
|
||||
}
|
||||
}
|
||||
return el
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
|
||||
retransmission, ok := h.packetMap[retransmissionOf]
|
||||
// The retransmitted packet is not present anymore.
|
||||
// This can happen if it was acked in between dequeueing of the retransmission and sending.
|
||||
// Just treat the retransmissions as normal packets.
|
||||
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
|
||||
if !ok {
|
||||
for _, packet := range packets {
|
||||
h.sentPacketImpl(packet)
|
||||
}
|
||||
return
|
||||
}
|
||||
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
|
||||
for i, packet := range packets {
|
||||
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
|
||||
el := h.sentPacketImpl(packet)
|
||||
el.Value.isRetransmission = true
|
||||
el.Value.retransmissionOf = retransmissionOf
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
|
||||
if el, ok := h.packetMap[p]; ok {
|
||||
return &el.Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Iterate iterates through all packets.
|
||||
// The callback must not modify the history.
|
||||
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
|
||||
cont := true
|
||||
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
|
||||
var err error
|
||||
cont, err = cb(&el.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FirstOutStanding returns the first outstanding packet.
|
||||
// It must not be modified (e.g. retransmitted).
|
||||
// Use DequeueFirstPacketForRetransmission() to retransmit it.
|
||||
func (h *sentPacketHistory) FirstOutstanding() *Packet {
|
||||
if h.firstOutstanding == nil {
|
||||
return nil
|
||||
}
|
||||
return &h.firstOutstanding.Value
|
||||
}
|
||||
|
||||
// QueuePacketForRetransmission marks a packet for retransmission.
|
||||
// A packet can only be queued once.
|
||||
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
|
||||
el, ok := h.packetMap[pn]
|
||||
if !ok {
|
||||
return fmt.Errorf("sent packet history: packet %d not found", pn)
|
||||
}
|
||||
if el.Value.canBeRetransmitted {
|
||||
h.numOutstandingPackets--
|
||||
if h.numOutstandingPackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.numOutstandingHandshakePackets--
|
||||
if h.numOutstandingHandshakePackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
}
|
||||
}
|
||||
el.Value.canBeRetransmitted = false
|
||||
if el == h.firstOutstanding {
|
||||
h.readjustFirstOutstanding()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
|
||||
// This is necessary every time the first outstanding packet is deleted or retransmitted.
|
||||
func (h *sentPacketHistory) readjustFirstOutstanding() {
|
||||
el := h.firstOutstanding.Next()
|
||||
for el != nil && !el.Value.canBeRetransmitted {
|
||||
el = el.Next()
|
||||
}
|
||||
h.firstOutstanding = el
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Len() int {
|
||||
return len(h.packetMap)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
||||
el, ok := h.packetMap[p]
|
||||
if !ok {
|
||||
return fmt.Errorf("packet %d not found in sent packet history", p)
|
||||
}
|
||||
if el == h.firstOutstanding {
|
||||
h.readjustFirstOutstanding()
|
||||
}
|
||||
if el.Value.canBeRetransmitted {
|
||||
h.numOutstandingPackets--
|
||||
if h.numOutstandingPackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
|
||||
h.numOutstandingHandshakePackets--
|
||||
if h.numOutstandingHandshakePackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
}
|
||||
}
|
||||
h.packetList.Remove(el)
|
||||
delete(h.packetMap, p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) HasOutstandingPackets() bool {
|
||||
return h.numOutstandingPackets > 0
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool {
|
||||
return h.numOutstandingHandshakePackets > 0
|
||||
}
|
297
vendor/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history_test.go
vendored
Normal file
297
vendor/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history_test.go
vendored
Normal file
|
@ -0,0 +1,297 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("SentPacketHistory", func() {
|
||||
var hist *sentPacketHistory
|
||||
|
||||
expectInHistory := func(packetNumbers []protocol.PacketNumber) {
|
||||
ExpectWithOffset(1, hist.packetMap).To(HaveLen(len(packetNumbers)))
|
||||
ExpectWithOffset(1, hist.packetList.Len()).To(Equal(len(packetNumbers)))
|
||||
i := 0
|
||||
hist.Iterate(func(p *Packet) (bool, error) {
|
||||
pn := packetNumbers[i]
|
||||
ExpectWithOffset(1, p.PacketNumber).To(Equal(pn))
|
||||
ExpectWithOffset(1, hist.packetMap[pn].Value.PacketNumber).To(Equal(pn))
|
||||
i++
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
hist = newSentPacketHistory()
|
||||
})
|
||||
|
||||
It("saves sent packets", func() {
|
||||
hist.SentPacket(&Packet{PacketNumber: 1})
|
||||
hist.SentPacket(&Packet{PacketNumber: 3})
|
||||
hist.SentPacket(&Packet{PacketNumber: 4})
|
||||
expectInHistory([]protocol.PacketNumber{1, 3, 4})
|
||||
})
|
||||
|
||||
It("gets the length", func() {
|
||||
hist.SentPacket(&Packet{PacketNumber: 1})
|
||||
hist.SentPacket(&Packet{PacketNumber: 10})
|
||||
Expect(hist.Len()).To(Equal(2))
|
||||
})
|
||||
|
||||
Context("getting the first outstanding packet", func() {
|
||||
It("gets nil, if there are no packets", func() {
|
||||
Expect(hist.FirstOutstanding()).To(BeNil())
|
||||
})
|
||||
|
||||
It("gets the first outstanding packet", func() {
|
||||
hist.SentPacket(&Packet{PacketNumber: 2})
|
||||
hist.SentPacket(&Packet{PacketNumber: 3})
|
||||
front := hist.FirstOutstanding()
|
||||
Expect(front).ToNot(BeNil())
|
||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(2)))
|
||||
})
|
||||
|
||||
It("gets the second packet if the first one is retransmitted", func() {
|
||||
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
|
||||
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
|
||||
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
|
||||
front := hist.FirstOutstanding()
|
||||
Expect(front).ToNot(BeNil())
|
||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
// Queue the first packet for retransmission.
|
||||
// The first outstanding packet should now be 3.
|
||||
err := hist.MarkCannotBeRetransmitted(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
front = hist.FirstOutstanding()
|
||||
Expect(front).ToNot(BeNil())
|
||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(3)))
|
||||
})
|
||||
|
||||
It("gets the third packet if the first two are retransmitted", func() {
|
||||
hist.SentPacket(&Packet{PacketNumber: 1, canBeRetransmitted: true})
|
||||
hist.SentPacket(&Packet{PacketNumber: 3, canBeRetransmitted: true})
|
||||
hist.SentPacket(&Packet{PacketNumber: 4, canBeRetransmitted: true})
|
||||
front := hist.FirstOutstanding()
|
||||
Expect(front).ToNot(BeNil())
|
||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
// Queue the second packet for retransmission.
|
||||
// The first outstanding packet should still be 3.
|
||||
err := hist.MarkCannotBeRetransmitted(3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
front = hist.FirstOutstanding()
|
||||
Expect(front).ToNot(BeNil())
|
||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(1)))
|
||||
// Queue the first packet for retransmission.
|
||||
// The first outstanding packet should still be 4.
|
||||
err = hist.MarkCannotBeRetransmitted(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
front = hist.FirstOutstanding()
|
||||
Expect(front).ToNot(BeNil())
|
||||
Expect(front.PacketNumber).To(Equal(protocol.PacketNumber(4)))
|
||||
})
|
||||
})
|
||||
|
||||
It("gets a packet by packet number", func() {
|
||||
p := &Packet{PacketNumber: 2}
|
||||
hist.SentPacket(p)
|
||||
Expect(hist.GetPacket(2)).To(Equal(p))
|
||||
})
|
||||
|
||||
It("returns nil if the packet doesn't exist", func() {
|
||||
Expect(hist.GetPacket(1337)).To(BeNil())
|
||||
})
|
||||
|
||||
It("removes packets", func() {
|
||||
hist.SentPacket(&Packet{PacketNumber: 1})
|
||||
hist.SentPacket(&Packet{PacketNumber: 4})
|
||||
hist.SentPacket(&Packet{PacketNumber: 8})
|
||||
err := hist.Remove(4)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInHistory([]protocol.PacketNumber{1, 8})
|
||||
})
|
||||
|
||||
It("errors when trying to remove a non existing packet", func() {
|
||||
hist.SentPacket(&Packet{PacketNumber: 1})
|
||||
err := hist.Remove(2)
|
||||
Expect(err).To(MatchError("packet 2 not found in sent packet history"))
|
||||
})
|
||||
|
||||
Context("iterating", func() {
|
||||
BeforeEach(func() {
|
||||
hist.SentPacket(&Packet{PacketNumber: 10})
|
||||
hist.SentPacket(&Packet{PacketNumber: 14})
|
||||
hist.SentPacket(&Packet{PacketNumber: 18})
|
||||
})
|
||||
|
||||
It("iterates over all packets", func() {
|
||||
var iterations []protocol.PacketNumber
|
||||
err := hist.Iterate(func(p *Packet) (bool, error) {
|
||||
iterations = append(iterations, p.PacketNumber)
|
||||
return true, nil
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14, 18}))
|
||||
})
|
||||
|
||||
It("stops iterating", func() {
|
||||
var iterations []protocol.PacketNumber
|
||||
err := hist.Iterate(func(p *Packet) (bool, error) {
|
||||
iterations = append(iterations, p.PacketNumber)
|
||||
return p.PacketNumber != 14, nil
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
|
||||
})
|
||||
|
||||
It("returns the error", func() {
|
||||
testErr := errors.New("test error")
|
||||
var iterations []protocol.PacketNumber
|
||||
err := hist.Iterate(func(p *Packet) (bool, error) {
|
||||
iterations = append(iterations, p.PacketNumber)
|
||||
if p.PacketNumber == 14 {
|
||||
return false, testErr
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(iterations).To(Equal([]protocol.PacketNumber{10, 14}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("retransmissions", func() {
|
||||
BeforeEach(func() {
|
||||
for i := protocol.PacketNumber(1); i <= 5; i++ {
|
||||
hist.SentPacket(&Packet{PacketNumber: i})
|
||||
}
|
||||
})
|
||||
|
||||
It("errors if the packet doesn't exist", func() {
|
||||
err := hist.MarkCannotBeRetransmitted(100)
|
||||
Expect(err).To(MatchError("sent packet history: packet 100 not found"))
|
||||
})
|
||||
|
||||
It("adds a sent packets as a retransmission", func() {
|
||||
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 2)
|
||||
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13})
|
||||
Expect(hist.GetPacket(13).isRetransmission).To(BeTrue())
|
||||
Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
|
||||
Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13}))
|
||||
})
|
||||
|
||||
It("adds multiple packets sent as a retransmission", func() {
|
||||
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}, {PacketNumber: 15}}, 2)
|
||||
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13, 15})
|
||||
Expect(hist.GetPacket(13).isRetransmission).To(BeTrue())
|
||||
Expect(hist.GetPacket(13).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
|
||||
Expect(hist.GetPacket(15).retransmissionOf).To(Equal(protocol.PacketNumber(2)))
|
||||
Expect(hist.GetPacket(2).retransmittedAs).To(Equal([]protocol.PacketNumber{13, 15}))
|
||||
})
|
||||
|
||||
It("adds a packet as a normal packet if the retransmitted packet doesn't exist", func() {
|
||||
hist.SentPacketsAsRetransmission([]*Packet{{PacketNumber: 13}}, 7)
|
||||
expectInHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 13})
|
||||
Expect(hist.GetPacket(13).isRetransmission).To(BeFalse())
|
||||
Expect(hist.GetPacket(13).retransmissionOf).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Context("outstanding packets", func() {
|
||||
It("says if it has outstanding handshake packets", func() {
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
hist.SentPacket(&Packet{
|
||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("says if it has outstanding packets", func() {
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
||||
hist.SentPacket(&Packet{
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't consider non-retransmittable packets as outstanding", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("accounts for deleted handshake packets", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 5,
|
||||
EncryptionLevel: protocol.EncryptionSecure,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
||||
err := hist.Remove(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("accounts for deleted packets", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 10,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
||||
err := hist.Remove(10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("doesn't count handshake packets marked as non-retransmittable", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 5,
|
||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
||||
err := hist.MarkCannotBeRetransmitted(5)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("doesn't count packets marked as non-retransmittable", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 10,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
||||
err := hist.MarkCannotBeRetransmitted(10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("counts the number of packets", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 10,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 11,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
err := hist.Remove(11)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
||||
err = hist.Remove(10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
|
@ -0,0 +1,43 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
|
||||
type stopWaitingManager struct {
|
||||
largestLeastUnackedSent protocol.PacketNumber
|
||||
nextLeastUnacked protocol.PacketNumber
|
||||
|
||||
lastStopWaitingFrame *wire.StopWaitingFrame
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
|
||||
if force {
|
||||
return s.lastStopWaitingFrame
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.largestLeastUnackedSent = s.nextLeastUnacked
|
||||
swf := &wire.StopWaitingFrame{
|
||||
LeastUnacked: s.nextLeastUnacked,
|
||||
}
|
||||
s.lastStopWaitingFrame = swf
|
||||
return swf
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
||||
largestAcked := ack.LargestAcked()
|
||||
if largestAcked >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = largestAcked + 1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
|
||||
if p >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = p + 1
|
||||
}
|
||||
}
|
55
vendor/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager_test.go
vendored
Normal file
55
vendor/lucas-clemente/quic-go/internal/ackhandler/stop_waiting_manager_test.go
vendored
Normal file
|
@ -0,0 +1,55 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StopWaitingManager", func() {
|
||||
var manager *stopWaitingManager
|
||||
BeforeEach(func() {
|
||||
manager = &stopWaitingManager{}
|
||||
})
|
||||
|
||||
It("returns nil in the beginning", func() {
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
|
||||
Expect(manager.GetStopWaitingFrame(true)).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns a StopWaitingFrame, when a new ACK arrives", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
||||
})
|
||||
|
||||
It("does not decrease the LeastUnacked", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 9}}})
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
||||
})
|
||||
|
||||
It("does not send the same StopWaitingFrame twice", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
|
||||
})
|
||||
|
||||
It("gets the same StopWaitingFrame twice, if forced", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
|
||||
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
|
||||
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("increases the LeastUnacked when a retransmission is queued", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
manager.QueuedRetransmissionForPacketNumber(20)
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 21}))
|
||||
})
|
||||
|
||||
It("does not decrease the LeastUnacked when a retransmission is queued", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
manager.QueuedRetransmissionForPacketNumber(9)
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,22 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Bandwidth of a connection
|
||||
type Bandwidth uint64
|
||||
|
||||
const (
|
||||
// BitsPerSecond is 1 bit per second
|
||||
BitsPerSecond Bandwidth = 1
|
||||
// BytesPerSecond is 1 byte per second
|
||||
BytesPerSecond = 8 * BitsPerSecond
|
||||
)
|
||||
|
||||
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
|
||||
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
|
||||
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Bandwidth", func() {
|
||||
It("converts from time delta", func() {
|
||||
Expect(BandwidthFromDelta(1, time.Millisecond)).To(Equal(1000 * BytesPerSecond))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,18 @@
|
|||
package congestion
|
||||
|
||||
import "time"
|
||||
|
||||
// A Clock returns the current time
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// DefaultClock implements the Clock interface using the Go stdlib clock.
|
||||
type DefaultClock struct{}
|
||||
|
||||
var _ Clock = DefaultClock{}
|
||||
|
||||
// Now gets the current time
|
||||
func (DefaultClock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCongestion(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Congestion Suite")
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
||||
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
|
||||
|
||||
// Constants based on TCP defaults.
|
||||
// The following constants are in 2^10 fractions of a second instead of ms to
|
||||
// allow a 10 shift right to divide.
|
||||
|
||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||
const cubeScale = 40
|
||||
const cubeCongestionWindowScale = 410
|
||||
const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
|
||||
|
||||
const defaultNumConnections = 2
|
||||
|
||||
// Default Cubic backoff factor
|
||||
const beta float32 = 0.7
|
||||
|
||||
// Additional backoff factor when loss occurs in the concave part of the Cubic
|
||||
// curve. This additional backoff factor is expected to give up bandwidth to
|
||||
// new concurrent flows and speed up convergence.
|
||||
const betaLastMax float32 = 0.85
|
||||
|
||||
// Cubic implements the cubic algorithm from TCP
|
||||
type Cubic struct {
|
||||
clock Clock
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// Time when this cycle started, after last loss event.
|
||||
epoch time.Time
|
||||
|
||||
// Max congestion window used just before last loss event.
|
||||
// Note: to improve fairness to other streams an additional back off is
|
||||
// applied to this value if the new value is below our latest value.
|
||||
lastMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Number of acked bytes since the cycle started (epoch).
|
||||
ackedBytesCount protocol.ByteCount
|
||||
|
||||
// TCP Reno equivalent congestion window in packets.
|
||||
estimatedTCPcongestionWindow protocol.ByteCount
|
||||
|
||||
// Origin point of cubic function.
|
||||
originPointCongestionWindow protocol.ByteCount
|
||||
|
||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||
timeToOriginPoint uint32
|
||||
|
||||
// Last congestion window in packets computed by cubic function.
|
||||
lastTargetCongestionWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
// NewCubic returns a new Cubic instance
|
||||
func NewCubic(clock Clock) *Cubic {
|
||||
c := &Cubic{
|
||||
clock: clock,
|
||||
numConnections: defaultNumConnections,
|
||||
}
|
||||
c.Reset()
|
||||
return c
|
||||
}
|
||||
|
||||
// Reset is called after a timeout to reset the cubic state
|
||||
func (c *Cubic) Reset() {
|
||||
c.epoch = time.Time{}
|
||||
c.lastMaxCongestionWindow = 0
|
||||
c.ackedBytesCount = 0
|
||||
c.estimatedTCPcongestionWindow = 0
|
||||
c.originPointCongestionWindow = 0
|
||||
c.timeToOriginPoint = 0
|
||||
c.lastTargetCongestionWindow = 0
|
||||
}
|
||||
|
||||
func (c *Cubic) alpha() float32 {
|
||||
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
|
||||
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
|
||||
// We derive the equivalent alpha for an N-connection emulation as:
|
||||
b := c.beta()
|
||||
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
|
||||
}
|
||||
|
||||
func (c *Cubic) beta() float32 {
|
||||
// kNConnectionBeta is the backoff factor after loss for our N-connection
|
||||
// emulation, which emulates the effective backoff of an ensemble of N
|
||||
// TCP-Reno connections on a single loss event. The effective multiplier is
|
||||
// computed as:
|
||||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
func (c *Cubic) betaLastMax() float32 {
|
||||
// betaLastMax is the additional backoff factor after loss for our
|
||||
// N-connection emulation, which emulates the additional backoff of
|
||||
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||
// effective multiplier is computed as:
|
||||
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||
// the available congestion window. Resets Cubic state during quiescence.
|
||||
func (c *Cubic) OnApplicationLimited() {
|
||||
// When sender is not using the available congestion window, the window does
|
||||
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||
// using the entire window during the time since the beginning of the current
|
||||
// "epoch" (the end of the last loss recovery period). Since
|
||||
// application-limited periods break this assumption, we reset the epoch when
|
||||
// in such a period. This reset effectively freezes congestion window growth
|
||||
// through application-limited periods and allows Cubic growth to continue
|
||||
// when the entire window is being used.
|
||||
c.epoch = time.Time{}
|
||||
}
|
||||
|
||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||
// a loss event. Returns the new congestion window in packets. The new
|
||||
// congestion window is a multiplicative decrease of our current window.
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
|
||||
if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
|
||||
// We never reached the old max, so assume we are competing with another
|
||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||
} else {
|
||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||
}
|
||||
c.epoch = time.Time{} // Reset time.
|
||||
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||
}
|
||||
|
||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||
// Returns the new congestion window in packets. The new congestion window
|
||||
// follows a cubic function that depends on the time passed since last
|
||||
// packet loss.
|
||||
func (c *Cubic) CongestionWindowAfterAck(
|
||||
ackedBytes protocol.ByteCount,
|
||||
currentCongestionWindow protocol.ByteCount,
|
||||
delayMin time.Duration,
|
||||
eventTime time.Time,
|
||||
) protocol.ByteCount {
|
||||
c.ackedBytesCount += ackedBytes
|
||||
|
||||
if c.epoch.IsZero() {
|
||||
// First ACK after a loss event.
|
||||
c.epoch = eventTime // Start of epoch.
|
||||
c.ackedBytesCount = ackedBytes // Reset count.
|
||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||
c.timeToOriginPoint = 0
|
||||
c.originPointCongestionWindow = currentCongestionWindow
|
||||
} else {
|
||||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||
}
|
||||
}
|
||||
|
||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||
// the round trip time in account. This is done to allow us to use shift as a
|
||||
// divide operator.
|
||||
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||
|
||||
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
|
||||
var targetCongestionWindow protocol.ByteCount
|
||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||
} else {
|
||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||
}
|
||||
// Limit the CWND increase to half the acked bytes.
|
||||
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||
|
||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||
// time we ack an estimated tcp window of bytes. For small
|
||||
// congestion windows (less than 25), the formula below will
|
||||
// increase slightly slower than linearly per estimated tcp window
|
||||
// of bytes.
|
||||
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
|
||||
c.ackedBytesCount = 0
|
||||
|
||||
// We have a new cubic congestion window.
|
||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||
|
||||
// Compute target congestion_window based on cubic target and estimated TCP
|
||||
// congestion_window, use highest (fastest).
|
||||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||
}
|
||||
return targetCongestionWindow
|
||||
}
|
||||
|
||||
// SetNumConnections sets the number of emulated connections
|
||||
func (c *Cubic) SetNumConnections(n int) {
|
||||
c.numConnections = n
|
||||
}
|
|
@ -0,0 +1,318 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBurstBytes = 3 * protocol.DefaultTCPMSS
|
||||
renoBeta float32 = 0.7 // Reno backoff factor.
|
||||
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
|
||||
)
|
||||
|
||||
type cubicSender struct {
|
||||
hybridSlowStart HybridSlowStart
|
||||
prr PrrSender
|
||||
rttStats *RTTStats
|
||||
stats connectionStats
|
||||
cubic *Cubic
|
||||
|
||||
reno bool
|
||||
|
||||
// Track the largest packet that has been sent.
|
||||
largestSentPacketNumber protocol.PacketNumber
|
||||
|
||||
// Track the largest packet that has been acked.
|
||||
largestAckedPacketNumber protocol.PacketNumber
|
||||
|
||||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||
largestSentAtLastCutback protocol.PacketNumber
|
||||
|
||||
// Whether the last loss event caused us to exit slowstart.
|
||||
// Used for stats collection of slowstartPacketsLost
|
||||
lastCutbackExitedSlowstart bool
|
||||
|
||||
// When true, exit slow start with large cutback of congestion window.
|
||||
slowStartLargeReduction bool
|
||||
|
||||
// Congestion window in packets.
|
||||
congestionWindow protocol.ByteCount
|
||||
|
||||
// Minimum congestion window in packets.
|
||||
minCongestionWindow protocol.ByteCount
|
||||
|
||||
// Maximum congestion window.
|
||||
maxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Slow start congestion window in bytes, aka ssthresh.
|
||||
slowstartThreshold protocol.ByteCount
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// ACK counter for the Reno implementation.
|
||||
numAckedPackets uint64
|
||||
|
||||
initialCongestionWindow protocol.ByteCount
|
||||
initialMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
minSlowStartExitWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
var _ SendAlgorithm = &cubicSender{}
|
||||
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
|
||||
|
||||
// NewCubicSender makes a new cubic sender
|
||||
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
|
||||
return &cubicSender{
|
||||
rttStats: rttStats,
|
||||
initialCongestionWindow: initialCongestionWindow,
|
||||
initialMaxCongestionWindow: initialMaxCongestionWindow,
|
||||
congestionWindow: initialCongestionWindow,
|
||||
minCongestionWindow: defaultMinimumCongestionWindow,
|
||||
slowstartThreshold: initialMaxCongestionWindow,
|
||||
maxCongestionWindow: initialMaxCongestionWindow,
|
||||
numConnections: defaultNumConnections,
|
||||
cubic: NewCubic(clock),
|
||||
reno: reno,
|
||||
}
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
|
||||
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
|
||||
delay = delay * 8 / 5
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketSent(
|
||||
sentTime time.Time,
|
||||
bytesInFlight protocol.ByteCount,
|
||||
packetNumber protocol.PacketNumber,
|
||||
bytes protocol.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
if !isRetransmittable {
|
||||
return
|
||||
}
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
c.prr.OnPacketSent(bytes)
|
||||
}
|
||||
c.largestSentPacketNumber = packetNumber
|
||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||
}
|
||||
|
||||
func (c *cubicSender) InRecovery() bool {
|
||||
return c.largestAckedPacketNumber <= c.largestSentAtLastCutback && c.largestAckedPacketNumber != 0
|
||||
}
|
||||
|
||||
func (c *cubicSender) InSlowStart() bool {
|
||||
return c.GetCongestionWindow() < c.GetSlowStartThreshold()
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
|
||||
return c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
|
||||
return c.slowstartThreshold
|
||||
}
|
||||
|
||||
func (c *cubicSender) ExitSlowstart() {
|
||||
c.slowstartThreshold = c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
|
||||
return c.slowstartThreshold
|
||||
}
|
||||
|
||||
func (c *cubicSender) MaybeExitSlowStart() {
|
||||
if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
|
||||
c.ExitSlowstart()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketAcked(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||
if c.InRecovery() {
|
||||
// PRR is used when in recovery.
|
||||
c.prr.OnPacketAcked(ackedBytes)
|
||||
return
|
||||
}
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||
if c.InSlowStart() {
|
||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketLost(
|
||||
packetNumber protocol.PacketNumber,
|
||||
lostBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
) {
|
||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||
// already sent should be treated as a single loss event, since it's expected.
|
||||
if packetNumber <= c.largestSentAtLastCutback {
|
||||
if c.lastCutbackExitedSlowstart {
|
||||
c.stats.slowstartPacketsLost++
|
||||
c.stats.slowstartBytesLost += lostBytes
|
||||
if c.slowStartLargeReduction {
|
||||
// Reduce congestion window by lost_bytes for every loss.
|
||||
c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
|
||||
c.slowstartThreshold = c.congestionWindow
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
c.lastCutbackExitedSlowstart = c.InSlowStart()
|
||||
if c.InSlowStart() {
|
||||
c.stats.slowstartPacketsLost++
|
||||
}
|
||||
|
||||
c.prr.OnPacketLost(priorInFlight)
|
||||
|
||||
// TODO(chromium): Separate out all of slow start into a separate class.
|
||||
if c.slowStartLargeReduction && c.InSlowStart() {
|
||||
if c.congestionWindow >= 2*c.initialCongestionWindow {
|
||||
c.minSlowStartExitWindow = c.congestionWindow / 2
|
||||
}
|
||||
c.congestionWindow = c.congestionWindow - protocol.DefaultTCPMSS
|
||||
} else if c.reno {
|
||||
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
|
||||
} else {
|
||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||
}
|
||||
if c.congestionWindow < c.minCongestionWindow {
|
||||
c.congestionWindow = c.minCongestionWindow
|
||||
}
|
||||
c.slowstartThreshold = c.congestionWindow
|
||||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||
// reset packet count from congestion avoidance mode. We start
|
||||
// counting again when we're out of recovery.
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
|
||||
func (c *cubicSender) RenoBeta() float32 {
|
||||
// kNConnectionBeta is the backoff factor after loss for our N-connection
|
||||
// emulation, which emulates the effective backoff of an ensemble of N
|
||||
// TCP-Reno connections on a single loss event. The effective multiplier is
|
||||
// computed as:
|
||||
return (float32(c.numConnections) - 1. + renoBeta) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||
// represents, but quic has a separate ack for each packet.
|
||||
func (c *cubicSender) maybeIncreaseCwnd(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
// Do not increase the congestion window unless the sender is close to using
|
||||
// the current window.
|
||||
if !c.isCwndLimited(priorInFlight) {
|
||||
c.cubic.OnApplicationLimited()
|
||||
return
|
||||
}
|
||||
if c.congestionWindow >= c.maxCongestionWindow {
|
||||
return
|
||||
}
|
||||
if c.InSlowStart() {
|
||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||
c.congestionWindow += protocol.DefaultTCPMSS
|
||||
return
|
||||
}
|
||||
// Congestion avoidance
|
||||
if c.reno {
|
||||
// Classic Reno congestion avoidance.
|
||||
c.numAckedPackets++
|
||||
// Divide by num_connections to smoothly increase the CWND at a faster
|
||||
// rate than conventional Reno.
|
||||
if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
|
||||
c.congestionWindow += protocol.DefaultTCPMSS
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
} else {
|
||||
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
|
||||
congestionWindow := c.GetCongestionWindow()
|
||||
if bytesInFlight >= congestionWindow {
|
||||
return true
|
||||
}
|
||||
availableBytes := congestionWindow - bytesInFlight
|
||||
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
|
||||
return slowStartLimited || availableBytes <= maxBurstBytes
|
||||
}
|
||||
|
||||
// BandwidthEstimate returns the current bandwidth estimate
|
||||
func (c *cubicSender) BandwidthEstimate() Bandwidth {
|
||||
srtt := c.rttStats.SmoothedRTT()
|
||||
if srtt == 0 {
|
||||
// If we haven't measured an rtt, the bandwidth estimate is unknown.
|
||||
return 0
|
||||
}
|
||||
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
|
||||
}
|
||||
|
||||
// HybridSlowStart returns the hybrid slow start instance for testing
|
||||
func (c *cubicSender) HybridSlowStart() *HybridSlowStart {
|
||||
return &c.hybridSlowStart
|
||||
}
|
||||
|
||||
// SetNumEmulatedConnections sets the number of emulated connections
|
||||
func (c *cubicSender) SetNumEmulatedConnections(n int) {
|
||||
c.numConnections = utils.Max(n, 1)
|
||||
c.cubic.SetNumConnections(c.numConnections)
|
||||
}
|
||||
|
||||
// OnRetransmissionTimeout is called on an retransmission timeout
|
||||
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
c.largestSentAtLastCutback = 0
|
||||
if !packetsRetransmitted {
|
||||
return
|
||||
}
|
||||
c.hybridSlowStart.Restart()
|
||||
c.cubic.Reset()
|
||||
c.slowstartThreshold = c.congestionWindow / 2
|
||||
c.congestionWindow = c.minCongestionWindow
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when the connection is migrated (?)
|
||||
func (c *cubicSender) OnConnectionMigration() {
|
||||
c.hybridSlowStart.Restart()
|
||||
c.prr = PrrSender{}
|
||||
c.largestSentPacketNumber = 0
|
||||
c.largestAckedPacketNumber = 0
|
||||
c.largestSentAtLastCutback = 0
|
||||
c.lastCutbackExitedSlowstart = false
|
||||
c.cubic.Reset()
|
||||
c.numAckedPackets = 0
|
||||
c.congestionWindow = c.initialCongestionWindow
|
||||
c.slowstartThreshold = c.initialMaxCongestionWindow
|
||||
c.maxCongestionWindow = c.initialMaxCongestionWindow
|
||||
}
|
||||
|
||||
// SetSlowStartLargeReduction allows enabling the SSLR experiment
|
||||
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
|
||||
c.slowStartLargeReduction = enabled
|
||||
}
|
|
@ -0,0 +1,640 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const initialCongestionWindowPackets = 10
|
||||
const defaultWindowTCP = protocol.ByteCount(initialCongestionWindowPackets) * protocol.DefaultTCPMSS
|
||||
|
||||
type mockClock time.Time
|
||||
|
||||
func (c *mockClock) Now() time.Time {
|
||||
return time.Time(*c)
|
||||
}
|
||||
|
||||
func (c *mockClock) Advance(d time.Duration) {
|
||||
*c = mockClock(time.Time(*c).Add(d))
|
||||
}
|
||||
|
||||
const MaxCongestionWindow protocol.ByteCount = 200 * protocol.DefaultTCPMSS
|
||||
|
||||
var _ = Describe("Cubic Sender", func() {
|
||||
var (
|
||||
sender SendAlgorithmWithDebugInfo
|
||||
clock mockClock
|
||||
bytesInFlight protocol.ByteCount
|
||||
packetNumber protocol.PacketNumber
|
||||
ackedPacketNumber protocol.PacketNumber
|
||||
rttStats *RTTStats
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
bytesInFlight = 0
|
||||
packetNumber = 1
|
||||
ackedPacketNumber = 0
|
||||
clock = mockClock{}
|
||||
rttStats = NewRTTStats()
|
||||
sender = NewCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*protocol.DefaultTCPMSS, MaxCongestionWindow)
|
||||
})
|
||||
|
||||
canSend := func() bool {
|
||||
return bytesInFlight < sender.GetCongestionWindow()
|
||||
}
|
||||
|
||||
SendAvailableSendWindowLen := func(packetLength protocol.ByteCount) int {
|
||||
packetsSent := 0
|
||||
for canSend() {
|
||||
sender.OnPacketSent(clock.Now(), bytesInFlight, packetNumber, packetLength, true)
|
||||
packetNumber++
|
||||
packetsSent++
|
||||
bytesInFlight += packetLength
|
||||
}
|
||||
return packetsSent
|
||||
}
|
||||
|
||||
// Normal is that TCP acks every other segment.
|
||||
AckNPackets := func(n int) {
|
||||
rttStats.UpdateRTT(60*time.Millisecond, 0, clock.Now())
|
||||
sender.MaybeExitSlowStart()
|
||||
for i := 0; i < n; i++ {
|
||||
ackedPacketNumber++
|
||||
sender.OnPacketAcked(ackedPacketNumber, protocol.DefaultTCPMSS, bytesInFlight, clock.Now())
|
||||
}
|
||||
bytesInFlight -= protocol.ByteCount(n) * protocol.DefaultTCPMSS
|
||||
clock.Advance(time.Millisecond)
|
||||
}
|
||||
|
||||
LoseNPacketsLen := func(n int, packetLength protocol.ByteCount) {
|
||||
for i := 0; i < n; i++ {
|
||||
ackedPacketNumber++
|
||||
sender.OnPacketLost(ackedPacketNumber, packetLength, bytesInFlight)
|
||||
}
|
||||
bytesInFlight -= protocol.ByteCount(n) * packetLength
|
||||
}
|
||||
|
||||
// Does not increment acked_packet_number_.
|
||||
LosePacket := func(number protocol.PacketNumber) {
|
||||
sender.OnPacketLost(number, protocol.DefaultTCPMSS, bytesInFlight)
|
||||
bytesInFlight -= protocol.DefaultTCPMSS
|
||||
}
|
||||
|
||||
SendAvailableSendWindow := func() int { return SendAvailableSendWindowLen(protocol.DefaultTCPMSS) }
|
||||
LoseNPackets := func(n int) { LoseNPacketsLen(n, protocol.DefaultTCPMSS) }
|
||||
|
||||
It("has the right values at startup", func() {
|
||||
// At startup make sure we are at the default.
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
||||
// Make sure we can send.
|
||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
||||
Expect(canSend()).To(BeTrue())
|
||||
// And that window is un-affected.
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
||||
|
||||
// Fill the send window with data, then verify that we can't send.
|
||||
SendAvailableSendWindow()
|
||||
Expect(canSend()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("paces", func() {
|
||||
clock.Advance(time.Hour)
|
||||
// Fill the send window with data, then verify that we can't send.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(1)
|
||||
delay := sender.TimeUntilSend(bytesInFlight)
|
||||
Expect(delay).ToNot(BeZero())
|
||||
Expect(delay).ToNot(Equal(utils.InfDuration))
|
||||
})
|
||||
|
||||
It("application limited slow start", func() {
|
||||
// Send exactly 10 packets and ensure the CWND ends at 14 packets.
|
||||
const numberOfAcks = 5
|
||||
// At startup make sure we can send.
|
||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
||||
// Make sure we can send.
|
||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
||||
|
||||
SendAvailableSendWindow()
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
AckNPackets(2)
|
||||
}
|
||||
bytesToSend := sender.GetCongestionWindow()
|
||||
// It's expected 2 acks will arrive when the bytes_in_flight are greater than
|
||||
// half the CWND.
|
||||
Expect(bytesToSend).To(Equal(defaultWindowTCP + protocol.DefaultTCPMSS*2*2))
|
||||
})
|
||||
|
||||
It("exponential slow start", func() {
|
||||
const numberOfAcks = 20
|
||||
// At startup make sure we can send.
|
||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
||||
Expect(sender.BandwidthEstimate()).To(BeZero())
|
||||
// Make sure we can send.
|
||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
||||
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
}
|
||||
cwnd := sender.GetCongestionWindow()
|
||||
Expect(cwnd).To(Equal(defaultWindowTCP + protocol.DefaultTCPMSS*2*numberOfAcks))
|
||||
Expect(sender.BandwidthEstimate()).To(Equal(BandwidthFromDelta(cwnd, rttStats.SmoothedRTT())))
|
||||
})
|
||||
|
||||
It("slow start packet loss", func() {
|
||||
sender.SetNumEmulatedConnections(1)
|
||||
const numberOfAcks = 10
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
}
|
||||
SendAvailableSendWindow()
|
||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Lose a packet to exit slow start.
|
||||
LoseNPackets(1)
|
||||
packetsInRecoveryWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
||||
|
||||
// We should now have fallen out of slow start with a reduced window.
|
||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Recovery phase. We need to ack every packet in the recovery window before
|
||||
// we exit recovery.
|
||||
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
||||
AckNPackets(int(packetsInRecoveryWindow))
|
||||
SendAvailableSendWindow()
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// We need to ack an entire window before we increase CWND by 1.
|
||||
AckNPackets(int(numberOfPacketsInWindow) - 2)
|
||||
SendAvailableSendWindow()
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Next ack should increase cwnd by 1.
|
||||
AckNPackets(1)
|
||||
expectedSendWindow += protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Now RTO and ensure slow start gets reset.
|
||||
Expect(sender.HybridSlowStart().Started()).To(BeTrue())
|
||||
sender.OnRetransmissionTimeout(true)
|
||||
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("slow start packet loss with large reduction", func() {
|
||||
sender.SetSlowStartLargeReduction(true)
|
||||
|
||||
sender.SetNumEmulatedConnections(1)
|
||||
const numberOfAcks = 10
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
}
|
||||
SendAvailableSendWindow()
|
||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Lose a packet to exit slow start. We should now have fallen out of
|
||||
// slow start with a window reduced by 1.
|
||||
LoseNPackets(1)
|
||||
expectedSendWindow -= protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Lose 5 packets in recovery and verify that congestion window is reduced
|
||||
// further.
|
||||
LoseNPackets(5)
|
||||
expectedSendWindow -= 5 * protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
packetsInRecoveryWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
||||
|
||||
// Recovery phase. We need to ack every packet in the recovery window before
|
||||
// we exit recovery.
|
||||
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
||||
AckNPackets(int(packetsInRecoveryWindow))
|
||||
SendAvailableSendWindow()
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// We need to ack the rest of the window before cwnd increases by 1.
|
||||
AckNPackets(int(numberOfPacketsInWindow - 1))
|
||||
SendAvailableSendWindow()
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Next ack should increase cwnd by 1.
|
||||
AckNPackets(1)
|
||||
expectedSendWindow += protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Now RTO and ensure slow start gets reset.
|
||||
Expect(sender.HybridSlowStart().Started()).To(BeTrue())
|
||||
sender.OnRetransmissionTimeout(true)
|
||||
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("slow start half packet loss with large reduction", func() {
|
||||
sender.SetSlowStartLargeReduction(true)
|
||||
|
||||
sender.SetNumEmulatedConnections(1)
|
||||
const numberOfAcks = 10
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window in half sized packets.
|
||||
SendAvailableSendWindowLen(protocol.DefaultTCPMSS / 2)
|
||||
AckNPackets(2)
|
||||
}
|
||||
SendAvailableSendWindowLen(protocol.DefaultTCPMSS / 2)
|
||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Lose a packet to exit slow start. We should now have fallen out of
|
||||
// slow start with a window reduced by 1.
|
||||
LoseNPackets(1)
|
||||
expectedSendWindow -= protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Lose 10 packets in recovery and verify that congestion window is reduced
|
||||
// by 5 packets.
|
||||
LoseNPacketsLen(10, protocol.DefaultTCPMSS/2)
|
||||
expectedSendWindow -= 5 * protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
})
|
||||
|
||||
// this test doesn't work any more after introducing the pacing needed for QUIC
|
||||
PIt("no PRR when less than one packet in flight", func() {
|
||||
SendAvailableSendWindow()
|
||||
LoseNPackets(int(initialCongestionWindowPackets) - 1)
|
||||
AckNPackets(1)
|
||||
// PRR will allow 2 packets for every ack during recovery.
|
||||
Expect(SendAvailableSendWindow()).To(Equal(2))
|
||||
// Simulate abandoning all packets by supplying a bytes_in_flight of 0.
|
||||
// PRR should now allow a packet to be sent, even though prr's state
|
||||
// variables believe it has sent enough packets.
|
||||
Expect(sender.TimeUntilSend(0)).To(BeZero())
|
||||
})
|
||||
|
||||
It("slow start packet loss PRR", func() {
|
||||
sender.SetNumEmulatedConnections(1)
|
||||
// Test based on the first example in RFC6937.
|
||||
// Ack 10 packets in 5 acks to raise the CWND to 20, as in the example.
|
||||
const numberOfAcks = 5
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
}
|
||||
SendAvailableSendWindow()
|
||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
LoseNPackets(1)
|
||||
|
||||
// We should now have fallen out of slow start with a reduced window.
|
||||
sendWindowBeforeLoss := expectedSendWindow
|
||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Testing TCP proportional rate reduction.
|
||||
// We should send packets paced over the received acks for the remaining
|
||||
// outstanding packets. The number of packets before we exit recovery is the
|
||||
// original CWND minus the packet that has been lost and the one which
|
||||
// triggered the loss.
|
||||
remainingPacketsInRecovery := sendWindowBeforeLoss/protocol.DefaultTCPMSS - 2
|
||||
|
||||
for i := protocol.ByteCount(0); i < remainingPacketsInRecovery; i++ {
|
||||
AckNPackets(1)
|
||||
SendAvailableSendWindow()
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
}
|
||||
|
||||
// We need to ack another window before we increase CWND by 1.
|
||||
numberOfPacketsInWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
||||
for i := protocol.ByteCount(0); i < numberOfPacketsInWindow; i++ {
|
||||
AckNPackets(1)
|
||||
Expect(SendAvailableSendWindow()).To(Equal(1))
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
}
|
||||
|
||||
AckNPackets(1)
|
||||
expectedSendWindow += protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
})
|
||||
|
||||
It("slow start burst packet loss PRR", func() {
|
||||
sender.SetNumEmulatedConnections(1)
|
||||
// Test based on the second example in RFC6937, though we also implement
|
||||
// forward acknowledgements, so the first two incoming acks will trigger
|
||||
// PRR immediately.
|
||||
// Ack 20 packets in 10 acks to raise the CWND to 30.
|
||||
const numberOfAcks = 10
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
}
|
||||
SendAvailableSendWindow()
|
||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Lose one more than the congestion window reduction, so that after loss,
|
||||
// bytes_in_flight is lesser than the congestion window.
|
||||
sendWindowAfterLoss := protocol.ByteCount(renoBeta * float32(expectedSendWindow))
|
||||
numPacketsToLose := (expectedSendWindow-sendWindowAfterLoss)/protocol.DefaultTCPMSS + 1
|
||||
LoseNPackets(int(numPacketsToLose))
|
||||
// Immediately after the loss, ensure at least one packet can be sent.
|
||||
// Losses without subsequent acks can occur with timer based loss detection.
|
||||
Expect(sender.TimeUntilSend(bytesInFlight)).To(BeZero())
|
||||
AckNPackets(1)
|
||||
|
||||
// We should now have fallen out of slow start with a reduced window.
|
||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Only 2 packets should be allowed to be sent, per PRR-SSRB
|
||||
Expect(SendAvailableSendWindow()).To(Equal(2))
|
||||
|
||||
// Ack the next packet, which triggers another loss.
|
||||
LoseNPackets(1)
|
||||
AckNPackets(1)
|
||||
|
||||
// Send 2 packets to simulate PRR-SSRB.
|
||||
Expect(SendAvailableSendWindow()).To(Equal(2))
|
||||
|
||||
// Ack the next packet, which triggers another loss.
|
||||
LoseNPackets(1)
|
||||
AckNPackets(1)
|
||||
|
||||
// Send 2 packets to simulate PRR-SSRB.
|
||||
Expect(SendAvailableSendWindow()).To(Equal(2))
|
||||
|
||||
// Exit recovery and return to sending at the new rate.
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
AckNPackets(1)
|
||||
Expect(SendAvailableSendWindow()).To(Equal(1))
|
||||
}
|
||||
})
|
||||
|
||||
It("RTO congestion window", func() {
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
||||
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
|
||||
|
||||
// Expect the window to decrease to the minimum once the RTO fires
|
||||
// and slow start threshold to be set to 1/2 of the CWND.
|
||||
sender.OnRetransmissionTimeout(true)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(2 * protocol.DefaultTCPMSS))
|
||||
Expect(sender.SlowstartThreshold()).To(Equal(5 * protocol.DefaultTCPMSS))
|
||||
})
|
||||
|
||||
It("RTO congestion window no retransmission", func() {
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
||||
|
||||
// Expect the window to remain unchanged if the RTO fires but no
|
||||
// packets are retransmitted.
|
||||
sender.OnRetransmissionTimeout(false)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
||||
})
|
||||
|
||||
It("tcp cubic reset epoch on quiescence", func() {
|
||||
const maxCongestionWindow = 50
|
||||
const maxCongestionWindowBytes = maxCongestionWindow * protocol.DefaultTCPMSS
|
||||
sender = NewCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*protocol.DefaultTCPMSS, maxCongestionWindowBytes)
|
||||
|
||||
numSent := SendAvailableSendWindow()
|
||||
|
||||
// Make sure we fall out of slow start.
|
||||
savedCwnd := sender.GetCongestionWindow()
|
||||
LoseNPackets(1)
|
||||
Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow()))
|
||||
|
||||
// Ack the rest of the outstanding packets to get out of recovery.
|
||||
for i := 1; i < numSent; i++ {
|
||||
AckNPackets(1)
|
||||
}
|
||||
Expect(bytesInFlight).To(BeZero())
|
||||
|
||||
// Send a new window of data and ack all; cubic growth should occur.
|
||||
savedCwnd = sender.GetCongestionWindow()
|
||||
numSent = SendAvailableSendWindow()
|
||||
for i := 0; i < numSent; i++ {
|
||||
AckNPackets(1)
|
||||
}
|
||||
Expect(savedCwnd).To(BeNumerically("<", sender.GetCongestionWindow()))
|
||||
Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow()))
|
||||
Expect(bytesInFlight).To(BeZero())
|
||||
|
||||
// Quiescent time of 100 seconds
|
||||
clock.Advance(100 * time.Second)
|
||||
|
||||
// Send new window of data and ack one packet. Cubic epoch should have
|
||||
// been reset; ensure cwnd increase is not dramatic.
|
||||
savedCwnd = sender.GetCongestionWindow()
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(1)
|
||||
Expect(savedCwnd).To(BeNumerically("~", sender.GetCongestionWindow(), protocol.DefaultTCPMSS))
|
||||
Expect(maxCongestionWindowBytes).To(BeNumerically(">", sender.GetCongestionWindow()))
|
||||
})
|
||||
|
||||
It("multiple losses in one window", func() {
|
||||
SendAvailableSendWindow()
|
||||
initialWindow := sender.GetCongestionWindow()
|
||||
LosePacket(ackedPacketNumber + 1)
|
||||
postLossWindow := sender.GetCongestionWindow()
|
||||
Expect(initialWindow).To(BeNumerically(">", postLossWindow))
|
||||
LosePacket(ackedPacketNumber + 3)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow))
|
||||
LosePacket(packetNumber - 1)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(postLossWindow))
|
||||
|
||||
// Lose a later packet and ensure the window decreases.
|
||||
LosePacket(packetNumber)
|
||||
Expect(postLossWindow).To(BeNumerically(">", sender.GetCongestionWindow()))
|
||||
})
|
||||
|
||||
It("2 connection congestion avoidance at end of recovery", func() {
|
||||
sender.SetNumEmulatedConnections(2)
|
||||
// Ack 10 packets in 5 acks to raise the CWND to 20.
|
||||
const numberOfAcks = 5
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
}
|
||||
SendAvailableSendWindow()
|
||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
LoseNPackets(1)
|
||||
|
||||
// We should now have fallen out of slow start with a reduced window.
|
||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * sender.RenoBeta())
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// No congestion window growth should occur in recovery phase, i.e., until the
|
||||
// currently outstanding 20 packets are acked.
|
||||
for i := 0; i < 10; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
Expect(sender.InRecovery()).To(BeTrue())
|
||||
AckNPackets(2)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
}
|
||||
Expect(sender.InRecovery()).To(BeFalse())
|
||||
|
||||
// Out of recovery now. Congestion window should not grow for half an RTT.
|
||||
packetsInSendWindow := expectedSendWindow / protocol.DefaultTCPMSS
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(int(packetsInSendWindow/2 - 2))
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Next ack should increase congestion window by 1MSS.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
expectedSendWindow += protocol.DefaultTCPMSS
|
||||
packetsInSendWindow++
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Congestion window should remain steady again for half an RTT.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(int(packetsInSendWindow/2 - 1))
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Next ack should cause congestion window to grow by 1MSS.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
expectedSendWindow += protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
})
|
||||
|
||||
It("1 connection congestion avoidance at end of recovery", func() {
|
||||
sender.SetNumEmulatedConnections(1)
|
||||
// Ack 10 packets in 5 acks to raise the CWND to 20.
|
||||
const numberOfAcks = 5
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
}
|
||||
SendAvailableSendWindow()
|
||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
LoseNPackets(1)
|
||||
|
||||
// We should now have fallen out of slow start with a reduced window.
|
||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// No congestion window growth should occur in recovery phase, i.e., until the
|
||||
// currently outstanding 20 packets are acked.
|
||||
for i := 0; i < 10; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
Expect(sender.InRecovery()).To(BeTrue())
|
||||
AckNPackets(2)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
}
|
||||
Expect(sender.InRecovery()).To(BeFalse())
|
||||
|
||||
// Out of recovery now. Congestion window should not grow during RTT.
|
||||
for i := protocol.ByteCount(0); i < expectedSendWindow/protocol.DefaultTCPMSS-2; i += 2 {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
}
|
||||
|
||||
// Next ack should cause congestion window to grow by 1MSS.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
expectedSendWindow += protocol.DefaultTCPMSS
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
})
|
||||
|
||||
It("reset after connection migration", func() {
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
||||
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
|
||||
|
||||
// Starts with slow start.
|
||||
sender.SetNumEmulatedConnections(1)
|
||||
const numberOfAcks = 10
|
||||
for i := 0; i < numberOfAcks; i++ {
|
||||
// Send our full send window.
|
||||
SendAvailableSendWindow()
|
||||
AckNPackets(2)
|
||||
}
|
||||
SendAvailableSendWindow()
|
||||
expectedSendWindow := defaultWindowTCP + (protocol.DefaultTCPMSS * 2 * numberOfAcks)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Loses a packet to exit slow start.
|
||||
LoseNPackets(1)
|
||||
|
||||
// We should now have fallen out of slow start with a reduced window. Slow
|
||||
// start threshold is also updated.
|
||||
expectedSendWindow = protocol.ByteCount(float32(expectedSendWindow) * renoBeta)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(expectedSendWindow))
|
||||
Expect(sender.SlowstartThreshold()).To(Equal(expectedSendWindow))
|
||||
|
||||
// Resets cwnd and slow start threshold on connection migrations.
|
||||
sender.OnConnectionMigration()
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(defaultWindowTCP))
|
||||
Expect(sender.SlowstartThreshold()).To(Equal(MaxCongestionWindow))
|
||||
Expect(sender.HybridSlowStart().Started()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("default max cwnd", func() {
|
||||
sender = NewCubicSender(&clock, rttStats, true /*reno*/, initialCongestionWindowPackets*protocol.DefaultTCPMSS, protocol.DefaultMaxCongestionWindow)
|
||||
|
||||
defaultMaxCongestionWindowPackets := protocol.DefaultMaxCongestionWindow / protocol.DefaultTCPMSS
|
||||
for i := 1; i < int(defaultMaxCongestionWindowPackets); i++ {
|
||||
sender.MaybeExitSlowStart()
|
||||
sender.OnPacketAcked(protocol.PacketNumber(i), 1350, sender.GetCongestionWindow(), clock.Now())
|
||||
}
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(protocol.DefaultMaxCongestionWindow))
|
||||
})
|
||||
|
||||
It("limit cwnd increase in congestion avoidance", func() {
|
||||
// Enable Cubic.
|
||||
sender = NewCubicSender(&clock, rttStats, false, initialCongestionWindowPackets*protocol.DefaultTCPMSS, MaxCongestionWindow)
|
||||
numSent := SendAvailableSendWindow()
|
||||
|
||||
// Make sure we fall out of slow start.
|
||||
savedCwnd := sender.GetCongestionWindow()
|
||||
LoseNPackets(1)
|
||||
Expect(savedCwnd).To(BeNumerically(">", sender.GetCongestionWindow()))
|
||||
|
||||
// Ack the rest of the outstanding packets to get out of recovery.
|
||||
for i := 1; i < numSent; i++ {
|
||||
AckNPackets(1)
|
||||
}
|
||||
Expect(bytesInFlight).To(BeZero())
|
||||
|
||||
savedCwnd = sender.GetCongestionWindow()
|
||||
SendAvailableSendWindow()
|
||||
|
||||
// Ack packets until the CWND increases.
|
||||
for sender.GetCongestionWindow() == savedCwnd {
|
||||
AckNPackets(1)
|
||||
SendAvailableSendWindow()
|
||||
}
|
||||
// Bytes in flight may be larger than the CWND if the CWND isn't an exact
|
||||
// multiple of the packet sizes being sent.
|
||||
Expect(bytesInFlight).To(BeNumerically(">=", sender.GetCongestionWindow()))
|
||||
savedCwnd = sender.GetCongestionWindow()
|
||||
|
||||
// Advance time 2 seconds waiting for an ack.
|
||||
clock.Advance(2 * time.Second)
|
||||
|
||||
// Ack two packets. The CWND should increase by only one packet.
|
||||
AckNPackets(2)
|
||||
Expect(sender.GetCongestionWindow()).To(Equal(savedCwnd + protocol.DefaultTCPMSS))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,236 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const numConnections uint32 = 2
|
||||
const nConnectionBeta float32 = (float32(numConnections) - 1 + beta) / float32(numConnections)
|
||||
const nConnectionBetaLastMax float32 = (float32(numConnections) - 1 + betaLastMax) / float32(numConnections)
|
||||
const nConnectionAlpha float32 = 3 * float32(numConnections) * float32(numConnections) * (1 - nConnectionBeta) / (1 + nConnectionBeta)
|
||||
const maxCubicTimeInterval = 30 * time.Millisecond
|
||||
|
||||
var _ = Describe("Cubic", func() {
|
||||
var (
|
||||
clock mockClock
|
||||
cubic *Cubic
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
clock = mockClock{}
|
||||
cubic = NewCubic(&clock)
|
||||
})
|
||||
|
||||
renoCwnd := func(currentCwnd protocol.ByteCount) protocol.ByteCount {
|
||||
return currentCwnd + protocol.ByteCount(float32(protocol.DefaultTCPMSS)*nConnectionAlpha*float32(protocol.DefaultTCPMSS)/float32(currentCwnd))
|
||||
}
|
||||
|
||||
cubicConvexCwnd := func(initialCwnd protocol.ByteCount, rtt, elapsedTime time.Duration) protocol.ByteCount {
|
||||
offset := protocol.ByteCount((elapsedTime+rtt)/time.Microsecond) << 10 / 1000000
|
||||
deltaCongestionWindow := 410 * offset * offset * offset * protocol.DefaultTCPMSS >> 40
|
||||
return initialCwnd + deltaCongestionWindow
|
||||
}
|
||||
|
||||
It("works above origin (with tighter bounds)", func() {
|
||||
// Convex growth.
|
||||
const rttMin = 100 * time.Millisecond
|
||||
const rttMinS = float32(rttMin/time.Millisecond) / 1000.0
|
||||
currentCwnd := 10 * protocol.DefaultTCPMSS
|
||||
initialCwnd := currentCwnd
|
||||
|
||||
clock.Advance(time.Millisecond)
|
||||
initialTime := clock.Now()
|
||||
expectedFirstCwnd := renoCwnd(currentCwnd)
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, initialTime)
|
||||
Expect(expectedFirstCwnd).To(Equal(currentCwnd))
|
||||
|
||||
// Normal TCP phase.
|
||||
// The maximum number of expected reno RTTs can be calculated by
|
||||
// finding the point where the cubic curve and the reno curve meet.
|
||||
maxRenoRtts := int(math.Sqrt(float64(nConnectionAlpha/(0.4*rttMinS*rttMinS*rttMinS))) - 2)
|
||||
for i := 0; i < maxRenoRtts; i++ {
|
||||
// Alternatively, we expect it to increase by one, every time we
|
||||
// receive current_cwnd/Alpha acks back. (This is another way of
|
||||
// saying we expect cwnd to increase by approximately Alpha once
|
||||
// we receive current_cwnd number ofacks back).
|
||||
numAcksThisEpoch := int(float32(currentCwnd/protocol.DefaultTCPMSS) / nConnectionAlpha)
|
||||
|
||||
initialCwndThisEpoch := currentCwnd
|
||||
for n := 0; n < numAcksThisEpoch; n++ {
|
||||
// Call once per ACK.
|
||||
expectedNextCwnd := renoCwnd(currentCwnd)
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
Expect(currentCwnd).To(Equal(expectedNextCwnd))
|
||||
}
|
||||
// Our byte-wise Reno implementation is an estimate. We expect
|
||||
// the cwnd to increase by approximately one MSS every
|
||||
// cwnd/kDefaultTCPMSS/Alpha acks, but it may be off by as much as
|
||||
// half a packet for smaller values of current_cwnd.
|
||||
cwndChangeThisEpoch := currentCwnd - initialCwndThisEpoch
|
||||
Expect(cwndChangeThisEpoch).To(BeNumerically("~", protocol.DefaultTCPMSS, protocol.DefaultTCPMSS/2))
|
||||
clock.Advance(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
for i := 0; i < 54; i++ {
|
||||
maxAcksThisEpoch := currentCwnd / protocol.DefaultTCPMSS
|
||||
interval := time.Duration(100*1000/maxAcksThisEpoch) * time.Microsecond
|
||||
for n := 0; n < int(maxAcksThisEpoch); n++ {
|
||||
clock.Advance(interval)
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
|
||||
// If we allow per-ack updates, every update is a small cubic update.
|
||||
Expect(currentCwnd).To(Equal(expectedCwnd))
|
||||
}
|
||||
}
|
||||
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
Expect(currentCwnd).To(Equal(expectedCwnd))
|
||||
})
|
||||
|
||||
It("works above the origin with fine grained cubing", func() {
|
||||
// Start the test with an artificially large cwnd to prevent Reno
|
||||
// from over-taking cubic.
|
||||
currentCwnd := 1000 * protocol.DefaultTCPMSS
|
||||
initialCwnd := currentCwnd
|
||||
rttMin := 100 * time.Millisecond
|
||||
clock.Advance(time.Millisecond)
|
||||
initialTime := clock.Now()
|
||||
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
clock.Advance(600 * time.Millisecond)
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
|
||||
// We expect the algorithm to perform only non-zero, fine-grained cubic
|
||||
// increases on every ack in this case.
|
||||
for i := 0; i < 100; i++ {
|
||||
clock.Advance(10 * time.Millisecond)
|
||||
expectedCwnd := cubicConvexCwnd(initialCwnd, rttMin, clock.Now().Sub(initialTime))
|
||||
nextCwnd := cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
// Make sure we are performing cubic increases.
|
||||
Expect(nextCwnd).To(Equal(expectedCwnd))
|
||||
// Make sure that these are non-zero, less-than-packet sized increases.
|
||||
Expect(nextCwnd).To(BeNumerically(">", currentCwnd))
|
||||
cwndDelta := nextCwnd - currentCwnd
|
||||
Expect(protocol.DefaultTCPMSS / 10).To(BeNumerically(">", cwndDelta))
|
||||
currentCwnd = nextCwnd
|
||||
}
|
||||
})
|
||||
|
||||
It("handles per ack updates", func() {
|
||||
// Start the test with a large cwnd and RTT, to force the first
|
||||
// increase to be a cubic increase.
|
||||
initialCwndPackets := 150
|
||||
currentCwnd := protocol.ByteCount(initialCwndPackets) * protocol.DefaultTCPMSS
|
||||
rttMin := 350 * time.Millisecond
|
||||
|
||||
// Initialize the epoch
|
||||
clock.Advance(time.Millisecond)
|
||||
// Keep track of the growth of the reno-equivalent cwnd.
|
||||
rCwnd := renoCwnd(currentCwnd)
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
initialCwnd := currentCwnd
|
||||
|
||||
// Simulate the return of cwnd packets in less than
|
||||
// MaxCubicInterval() time.
|
||||
maxAcks := int(float32(initialCwndPackets) / nConnectionAlpha)
|
||||
interval := maxCubicTimeInterval / time.Duration(maxAcks+1)
|
||||
|
||||
// In this scenario, the first increase is dictated by the cubic
|
||||
// equation, but it is less than one byte, so the cwnd doesn't
|
||||
// change. Normally, without per-ack increases, any cwnd plateau
|
||||
// will cause the cwnd to be pinned for MaxCubicTimeInterval(). If
|
||||
// we enable per-ack updates, the cwnd will continue to grow,
|
||||
// regardless of the temporary plateau.
|
||||
clock.Advance(interval)
|
||||
rCwnd = renoCwnd(rCwnd)
|
||||
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(currentCwnd))
|
||||
for i := 1; i < maxAcks; i++ {
|
||||
clock.Advance(interval)
|
||||
nextCwnd := cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
rCwnd = renoCwnd(rCwnd)
|
||||
// The window shoud increase on every ack.
|
||||
Expect(nextCwnd).To(BeNumerically(">", currentCwnd))
|
||||
Expect(nextCwnd).To(Equal(rCwnd))
|
||||
currentCwnd = nextCwnd
|
||||
}
|
||||
|
||||
// After all the acks are returned from the epoch, we expect the
|
||||
// cwnd to have increased by nearly one packet. (Not exactly one
|
||||
// packet, because our byte-wise Reno algorithm is always a slight
|
||||
// under-estimation). Without per-ack updates, the current_cwnd
|
||||
// would otherwise be unchanged.
|
||||
minimumExpectedIncrease := protocol.DefaultTCPMSS * 9 / 10
|
||||
Expect(currentCwnd).To(BeNumerically(">", initialCwnd+minimumExpectedIncrease))
|
||||
})
|
||||
|
||||
It("handles loss events", func() {
|
||||
rttMin := 100 * time.Millisecond
|
||||
currentCwnd := 422 * protocol.DefaultTCPMSS
|
||||
expectedCwnd := renoCwnd(currentCwnd)
|
||||
// Initialize the state.
|
||||
clock.Advance(time.Millisecond)
|
||||
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd))
|
||||
|
||||
// On the first loss, the last max congestion window is set to the
|
||||
// congestion window before the loss.
|
||||
preLossCwnd := currentCwnd
|
||||
Expect(cubic.lastMaxCongestionWindow).To(BeZero())
|
||||
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
|
||||
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
|
||||
Expect(cubic.lastMaxCongestionWindow).To(Equal(preLossCwnd))
|
||||
currentCwnd = expectedCwnd
|
||||
|
||||
// On the second loss, the current congestion window has not yet
|
||||
// reached the last max congestion window. The last max congestion
|
||||
// window will be reduced by an additional backoff factor to allow
|
||||
// for competition.
|
||||
preLossCwnd = currentCwnd
|
||||
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
|
||||
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
|
||||
currentCwnd = expectedCwnd
|
||||
Expect(preLossCwnd).To(BeNumerically(">", cubic.lastMaxCongestionWindow))
|
||||
expectedLastMax := protocol.ByteCount(float32(preLossCwnd) * nConnectionBetaLastMax)
|
||||
Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax))
|
||||
Expect(expectedCwnd).To(BeNumerically("<", cubic.lastMaxCongestionWindow))
|
||||
// Simulate an increase, and check that we are below the origin.
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
Expect(cubic.lastMaxCongestionWindow).To(BeNumerically(">", currentCwnd))
|
||||
|
||||
// On the final loss, simulate the condition where the congestion
|
||||
// window had a chance to grow nearly to the last congestion window.
|
||||
currentCwnd = cubic.lastMaxCongestionWindow - 1
|
||||
preLossCwnd = currentCwnd
|
||||
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
|
||||
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
|
||||
expectedLastMax = preLossCwnd
|
||||
Expect(cubic.lastMaxCongestionWindow).To(Equal(expectedLastMax))
|
||||
})
|
||||
|
||||
It("works below origin", func() {
|
||||
// Concave growth.
|
||||
rttMin := 100 * time.Millisecond
|
||||
currentCwnd := 422 * protocol.DefaultTCPMSS
|
||||
expectedCwnd := renoCwnd(currentCwnd)
|
||||
// Initialize the state.
|
||||
clock.Advance(time.Millisecond)
|
||||
Expect(cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())).To(Equal(expectedCwnd))
|
||||
|
||||
expectedCwnd = protocol.ByteCount(float32(currentCwnd) * nConnectionBeta)
|
||||
Expect(cubic.CongestionWindowAfterPacketLoss(currentCwnd)).To(Equal(expectedCwnd))
|
||||
currentCwnd = expectedCwnd
|
||||
// First update after loss to initialize the epoch.
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
// Cubic phase.
|
||||
for i := 0; i < 40; i++ {
|
||||
clock.Advance(100 * time.Millisecond)
|
||||
currentCwnd = cubic.CongestionWindowAfterAck(protocol.DefaultTCPMSS, currentCwnd, rttMin, clock.Now())
|
||||
}
|
||||
expectedCwnd = 553632
|
||||
Expect(currentCwnd).To(Equal(expectedCwnd))
|
||||
})
|
||||
})
|
|
@ -0,0 +1,111 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
||||
// tcp_cubic.c.
|
||||
const hybridStartLowWindow = protocol.ByteCount(16)
|
||||
|
||||
// Number of delay samples for detecting the increase of delay.
|
||||
const hybridStartMinSamples = uint32(8)
|
||||
|
||||
// Exit slow start if the min rtt has increased by more than 1/8th.
|
||||
const hybridStartDelayFactorExp = 3 // 2^3 = 8
|
||||
// The original paper specifies 2 and 8ms, but those have changed over time.
|
||||
const hybridStartDelayMinThresholdUs = int64(4000)
|
||||
const hybridStartDelayMaxThresholdUs = int64(16000)
|
||||
|
||||
// HybridSlowStart implements the TCP hybrid slow start algorithm
|
||||
type HybridSlowStart struct {
|
||||
endPacketNumber protocol.PacketNumber
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
started bool
|
||||
currentMinRTT time.Duration
|
||||
rttSampleCount uint32
|
||||
hystartFound bool
|
||||
}
|
||||
|
||||
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
|
||||
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
|
||||
s.endPacketNumber = lastSent
|
||||
s.currentMinRTT = 0
|
||||
s.rttSampleCount = 0
|
||||
s.started = true
|
||||
}
|
||||
|
||||
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
|
||||
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
|
||||
return s.endPacketNumber < ack
|
||||
}
|
||||
|
||||
// ShouldExitSlowStart should be called on every new ack frame, since a new
|
||||
// RTT measurement can be made then.
|
||||
// rtt: the RTT for this ack packet.
|
||||
// minRTT: is the lowest delay (RTT) we have seen during the session.
|
||||
// congestionWindow: the congestion window in packets.
|
||||
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
|
||||
if !s.started {
|
||||
// Time to start the hybrid slow start.
|
||||
s.StartReceiveRound(s.lastSentPacketNumber)
|
||||
}
|
||||
if s.hystartFound {
|
||||
return true
|
||||
}
|
||||
// Second detection parameter - delay increase detection.
|
||||
// Compare the minimum delay (s.currentMinRTT) of the current
|
||||
// burst of packets relative to the minimum delay during the session.
|
||||
// Note: we only look at the first few(8) packets in each burst, since we
|
||||
// only want to compare the lowest RTT of the burst relative to previous
|
||||
// bursts.
|
||||
s.rttSampleCount++
|
||||
if s.rttSampleCount <= hybridStartMinSamples {
|
||||
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
|
||||
s.currentMinRTT = latestRTT
|
||||
}
|
||||
}
|
||||
// We only need to check this once per round.
|
||||
if s.rttSampleCount == hybridStartMinSamples {
|
||||
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
||||
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
||||
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
||||
minRTTincreaseThresholdUs = utils.MinInt64(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
||||
minRTTincreaseThreshold := time.Duration(utils.MaxInt64(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
||||
|
||||
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
||||
s.hystartFound = true
|
||||
}
|
||||
}
|
||||
// Exit from slow start if the cwnd is greater than 16 and
|
||||
// increasing delay is found.
|
||||
return congestionWindow >= hybridStartLowWindow && s.hystartFound
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet was sent
|
||||
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
|
||||
s.lastSentPacketNumber = packetNumber
|
||||
}
|
||||
|
||||
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
|
||||
// the round when the final packet of the burst is received and start it on
|
||||
// the next incoming ack.
|
||||
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
|
||||
if s.IsEndOfRound(ackedPacketNumber) {
|
||||
s.started = false
|
||||
}
|
||||
}
|
||||
|
||||
// Started returns true if started
|
||||
func (s *HybridSlowStart) Started() bool {
|
||||
return s.started
|
||||
}
|
||||
|
||||
// Restart the slow start phase
|
||||
func (s *HybridSlowStart) Restart() {
|
||||
s.started = false
|
||||
s.hystartFound = false
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Hybrid slow start", func() {
|
||||
var (
|
||||
slowStart HybridSlowStart
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
slowStart = HybridSlowStart{}
|
||||
})
|
||||
|
||||
It("works in a simple case", func() {
|
||||
packetNumber := protocol.PacketNumber(1)
|
||||
endPacketNumber := protocol.PacketNumber(3)
|
||||
slowStart.StartReceiveRound(endPacketNumber)
|
||||
|
||||
packetNumber++
|
||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
|
||||
|
||||
// Test duplicates.
|
||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
|
||||
|
||||
packetNumber++
|
||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
|
||||
packetNumber++
|
||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
|
||||
|
||||
// Test without a new registered end_packet_number;
|
||||
packetNumber++
|
||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
|
||||
|
||||
endPacketNumber = 20
|
||||
slowStart.StartReceiveRound(endPacketNumber)
|
||||
for packetNumber < endPacketNumber {
|
||||
packetNumber++
|
||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeFalse())
|
||||
}
|
||||
packetNumber++
|
||||
Expect(slowStart.IsEndOfRound(packetNumber)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("works with delay", func() {
|
||||
rtt := 60 * time.Millisecond
|
||||
// We expect to detect the increase at +1/8 of the RTT; hence at a typical
|
||||
// RTT of 60ms the detection will happen at 67.5 ms.
|
||||
const hybridStartMinSamples = 8 // Number of acks required to trigger.
|
||||
|
||||
endPacketNumber := protocol.PacketNumber(1)
|
||||
endPacketNumber++
|
||||
slowStart.StartReceiveRound(endPacketNumber)
|
||||
|
||||
// Will not trigger since our lowest RTT in our burst is the same as the long
|
||||
// term RTT provided.
|
||||
for n := 0; n < hybridStartMinSamples; n++ {
|
||||
Expect(slowStart.ShouldExitSlowStart(rtt+time.Duration(n)*time.Millisecond, rtt, 100)).To(BeFalse())
|
||||
}
|
||||
endPacketNumber++
|
||||
slowStart.StartReceiveRound(endPacketNumber)
|
||||
for n := 1; n < hybridStartMinSamples; n++ {
|
||||
Expect(slowStart.ShouldExitSlowStart(rtt+(time.Duration(n)+10)*time.Millisecond, rtt, 100)).To(BeFalse())
|
||||
}
|
||||
// Expect to trigger since all packets in this burst was above the long term
|
||||
// RTT provided.
|
||||
Expect(slowStart.ShouldExitSlowStart(rtt+10*time.Millisecond, rtt, 100)).To(BeTrue())
|
||||
})
|
||||
|
||||
})
|
|
@ -0,0 +1,36 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// A SendAlgorithm performs congestion control and calculates the congestion window
|
||||
type SendAlgorithm interface {
|
||||
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
|
||||
GetCongestionWindow() protocol.ByteCount
|
||||
MaybeExitSlowStart()
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
|
||||
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
|
||||
SetNumEmulatedConnections(n int)
|
||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||
OnConnectionMigration()
|
||||
|
||||
// Experiments
|
||||
SetSlowStartLargeReduction(enabled bool)
|
||||
}
|
||||
|
||||
// SendAlgorithmWithDebugInfo adds some debug functions to SendAlgorithm
|
||||
type SendAlgorithmWithDebugInfo interface {
|
||||
SendAlgorithm
|
||||
BandwidthEstimate() Bandwidth
|
||||
|
||||
// Stuff only used in testing
|
||||
|
||||
HybridSlowStart() *HybridSlowStart
|
||||
SlowstartThreshold() protocol.ByteCount
|
||||
RenoBeta() float32
|
||||
InRecovery() bool
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
|
||||
type PrrSender struct {
|
||||
bytesSentSinceLoss protocol.ByteCount
|
||||
bytesDeliveredSinceLoss protocol.ByteCount
|
||||
ackCountSinceLoss protocol.ByteCount
|
||||
bytesInFlightBeforeLoss protocol.ByteCount
|
||||
}
|
||||
|
||||
// OnPacketSent should be called after a packet was sent
|
||||
func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
|
||||
p.bytesSentSinceLoss += sentBytes
|
||||
}
|
||||
|
||||
// OnPacketLost should be called on the first loss that triggers a recovery
|
||||
// period and all other methods in this class should only be called when in
|
||||
// recovery.
|
||||
func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
|
||||
p.bytesSentSinceLoss = 0
|
||||
p.bytesInFlightBeforeLoss = priorInFlight
|
||||
p.bytesDeliveredSinceLoss = 0
|
||||
p.ackCountSinceLoss = 0
|
||||
}
|
||||
|
||||
// OnPacketAcked should be called after a packet was acked
|
||||
func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
|
||||
p.bytesDeliveredSinceLoss += ackedBytes
|
||||
p.ackCountSinceLoss++
|
||||
}
|
||||
|
||||
// CanSend returns if packets can be sent
|
||||
func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
|
||||
// Return QuicTime::Zero In order to ensure limited transmit always works.
|
||||
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
|
||||
return true
|
||||
}
|
||||
if congestionWindow > bytesInFlight {
|
||||
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
|
||||
// of sending the entire available window. This prevents burst retransmits
|
||||
// when more packets are lost than the CWND reduction.
|
||||
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
|
||||
return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
|
||||
}
|
||||
// Implement Proportional Rate Reduction (RFC6937).
|
||||
// Checks a simplified version of the PRR formula that doesn't use division:
|
||||
// AvailableSendWindow =
|
||||
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
|
||||
return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue