Compare commits

..

15 Commits

Author SHA1 Message Date
Darien Raymond
fd060a0880 temp fix deadlock in quic lib 2018-11-29 17:17:07 +01:00
Darien Raymond
3b1aaa9a3d update quic parameters 2018-11-28 23:14:41 +01:00
Darien Raymond
1d54fd7433 update timeout and window size in quic 2018-11-27 22:39:54 +01:00
Darien Raymond
828cfac630 remove unused file 2018-11-27 15:33:53 +01:00
Darien Raymond
135bf169c0 update quic vendor 2018-11-27 15:29:03 +01:00
Darien Raymond
90ab42b1cb close connection without error 2018-11-27 10:56:07 +01:00
Darien Raymond
fe0f32e0f1 vendor websocket 2018-11-27 10:08:34 +01:00
Darien Raymond
f70712a1a3 remove warning messages in quic 2018-11-26 18:00:41 +01:00
Darien Raymond
45fbf6f059 update quic connection handling 2018-11-25 21:56:24 +01:00
Darien Raymond
e0093206e7 update quic 2018-11-25 21:55:45 +01:00
Darien Raymond
5ecec6afd8 reduce concurrency 2018-11-24 23:10:04 +01:00
Darien Raymond
4ff26a36ad update quic connection handling 2018-11-24 22:18:17 +01:00
Darien Raymond
ad4908be9e disable logging completely 2018-11-24 22:17:41 +01:00
Darien Raymond
a1801b5d6f update quic config 2018-11-23 23:51:07 +01:00
Darien Raymond
275ab2ae8c fix inactive session removal 2018-11-23 20:47:36 +01:00
33 changed files with 3291 additions and 705 deletions

View File

@@ -162,10 +162,13 @@ func (c *interConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer reader.Close()
for {
nBytes, err := reader.Read(b[:1380])
nBytes, err := reader.Read(b[:1200])
if err != nil {
break
}
if nBytes == 0 {
continue
}
if _, err := c.Write(b[:nBytes]); err != nil {
return err
}

View File

@@ -6,54 +6,112 @@ import (
"time"
quic "github.com/lucas-clemente/quic-go"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/task"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/tls"
)
type clientSessions struct {
access sync.Mutex
sessions map[net.Destination][]quic.Session
type sessionContext struct {
rawConn *sysConn
session quic.Session
}
func removeInactiveSessions(sessions []quic.Session) []quic.Session {
lastActive := 0
var errSessionClosed = newError("session closed")
func (c *sessionContext) openStream(destAddr net.Addr) (*interConn, error) {
if !isActive(c.session) {
return nil, errSessionClosed
}
stream, err := c.session.OpenStream()
if err != nil {
return nil, err
}
conn := &interConn{
stream: stream,
local: c.session.LocalAddr(),
remote: destAddr,
}
return conn, nil
}
type clientSessions struct {
access sync.Mutex
sessions map[net.Destination][]*sessionContext
cleanup *task.Periodic
}
func isActive(s quic.Session) bool {
select {
case <-s.Context().Done():
return false
default:
return true
}
}
func removeInactiveSessions(sessions []*sessionContext) []*sessionContext {
activeSessions := make([]*sessionContext, 0, len(sessions))
for _, s := range sessions {
active := true
select {
case <-s.Context().Done():
active = false
default:
if isActive(s.session) {
activeSessions = append(activeSessions, s)
continue
}
if active {
sessions[lastActive] = s
lastActive++
if err := s.session.Close(); err != nil {
newError("failed to close session").Base(err).WriteToLog()
}
if err := s.rawConn.Close(); err != nil {
newError("failed to close raw connection").Base(err).WriteToLog()
}
}
if lastActive < len(sessions) {
for i := 0; i < len(sessions); i++ {
sessions[i] = nil
}
sessions = sessions[:lastActive]
if len(activeSessions) < len(sessions) {
return activeSessions
}
return sessions
}
func openStream(sessions []quic.Session) (quic.Stream, net.Addr, error) {
func openStream(sessions []*sessionContext, destAddr net.Addr) *interConn {
for _, s := range sessions {
stream, err := s.OpenStream()
if err != nil {
newError("failed to create stream").Base(err).WriteToLog()
if !isActive(s.session) {
continue
}
return stream, s.LocalAddr(), nil
conn, err := s.openStream(destAddr)
if err != nil {
continue
}
return conn
}
return nil, nil, nil
return nil
}
func (s *clientSessions) cleanSessions() error {
s.access.Lock()
defer s.access.Unlock()
if len(s.sessions) == 0 {
return nil
}
newSessionMap := make(map[net.Destination][]*sessionContext)
for dest, sessions := range s.sessions {
sessions = removeInactiveSessions(sessions)
if len(sessions) > 0 {
newSessionMap[dest] = sessions
}
}
s.sessions = newSessionMap
return nil
}
func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
@@ -61,30 +119,24 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
defer s.access.Unlock()
if s.sessions == nil {
s.sessions = make(map[net.Destination][]quic.Session)
s.sessions = make(map[net.Destination][]*sessionContext)
}
dest := net.DestinationFromAddr(destAddr)
var sessions []quic.Session
var sessions []*sessionContext
if s, found := s.sessions[dest]; found {
sessions = s
}
sessions = removeInactiveSessions(sessions)
s.sessions[dest] = sessions
if true {
conn := openStream(sessions, destAddr)
if conn != nil {
return conn, nil
}
}
stream, local, err := openStream(sessions)
if err != nil {
return nil, err
}
if stream != nil {
return &interConn{
stream: stream,
local: local,
remote: destAddr,
}, nil
}
sessions = removeInactiveSessions(sessions)
rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
IP: []byte{0, 0, 0, 0},
@@ -95,13 +147,9 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
}
quicConfig := &quic.Config{
ConnectionIDLength: 12,
KeepAlive: true,
HandshakeTimeout: time.Second * 4,
IdleTimeout: time.Second * 60,
MaxReceiveStreamFlowControlWindow: 256 * 1024,
MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
MaxIncomingUniStreams: -1,
ConnectionIDLength: 12,
HandshakeTimeout: time.Second * 8,
IdleTimeout: time.Second * 30,
}
conn, err := wrapSysConn(rawConn, config)
@@ -112,24 +160,29 @@ func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsCo
session, err := quic.DialContext(context.Background(), conn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
if err != nil {
rawConn.Close()
conn.Close()
return nil, err
}
s.sessions[dest] = append(sessions, session)
stream, err = session.OpenStream()
if err != nil {
return nil, err
context := &sessionContext{
session: session,
rawConn: conn,
}
return &interConn{
stream: stream,
local: session.LocalAddr(),
remote: destAddr,
}, nil
s.sessions[dest] = append(sessions, context)
return context.openStream(destAddr)
}
var client clientSessions
func init() {
client.sessions = make(map[net.Destination][]*sessionContext)
client.cleanup = &task.Periodic{
Interval: time.Minute,
Execute: client.cleanSessions,
}
common.Must(client.cleanup.Start())
}
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
if tlsConfig == nil {
@@ -139,9 +192,18 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
}
}
destAddr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
if err != nil {
return nil, err
var destAddr *net.UDPAddr
if dest.Address.Family().IsIP() {
destAddr = &net.UDPAddr{
IP: dest.Address.IP(),
Port: int(dest.Port),
}
} else {
addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
if err != nil {
return nil, err
}
destAddr = addr
}
config := streamSettings.ProtocolSettings.(*Config)

View File

@@ -8,13 +8,16 @@ import (
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol/tls/cert"
"v2ray.com/core/common/signal/done"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/tls"
)
// Listener is an internet.Listener that listens for TCP connections.
type Listener struct {
rawConn *sysConn
listener quic.Listener
done *done.Instance
addConn internet.ConnHandler
}
@@ -23,8 +26,16 @@ func (l *Listener) acceptStreams(session quic.Session) {
stream, err := session.AcceptStream()
if err != nil {
newError("failed to accept stream").Base(err).WriteToLog()
session.Close()
return
select {
case <-session.Context().Done():
return
case <-l.done.Wait():
session.Close()
return
default:
time.Sleep(time.Second)
continue
}
}
conn := &interConn{
@@ -43,8 +54,11 @@ func (l *Listener) keepAccepting() {
conn, err := l.listener.Accept()
if err != nil {
newError("failed to accept QUIC sessions").Base(err).WriteToLog()
l.listener.Close()
return
if l.done.Done() {
break
}
time.Sleep(time.Second)
continue
}
go l.acceptStreams(conn)
}
@@ -57,7 +71,10 @@ func (l *Listener) Addr() net.Addr {
// Close implements internet.Listener.Close.
func (l *Listener) Close() error {
return l.listener.Close()
l.done.Close()
l.listener.Close()
l.rawConn.Close()
return nil
}
// Listen creates a new Listener based on configurations.
@@ -84,14 +101,11 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
}
quicConfig := &quic.Config{
ConnectionIDLength: 12,
KeepAlive: true,
HandshakeTimeout: time.Second * 4,
IdleTimeout: time.Second * 60,
MaxReceiveStreamFlowControlWindow: 256 * 1024,
MaxReceiveConnectionFlowControlWindow: 2 * 1024 * 1024,
MaxIncomingStreams: 64,
MaxIncomingUniStreams: -1,
ConnectionIDLength: 12,
HandshakeTimeout: time.Second * 8,
IdleTimeout: time.Second * 30,
MaxIncomingStreams: 128,
MaxIncomingUniStreams: 32,
}
conn, err := wrapSysConn(rawConn, config)
@@ -102,11 +116,13 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
qListener, err := quic.Listen(conn, tlsConfig.GetTLSConfig(), quicConfig)
if err != nil {
rawConn.Close()
conn.Close()
return nil, err
}
listener := &Listener{
done: done.New(),
rawConn: conn,
listener: qListener,
addConn: handler,
}

9
vendor/github.com/gorilla/websocket/AUTHORS generated vendored Normal file
View File

@@ -0,0 +1,9 @@
# This is the official list of Gorilla WebSocket authors for copyright
# purposes.
#
# Please keep the list sorted.
Gary Burd <gary@beagledreams.com>
Google LLC (https://opensource.google.com/)
Joachim Bauch <mail@joachim-bauch.de>

22
vendor/github.com/gorilla/websocket/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,22 @@
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

64
vendor/github.com/gorilla/websocket/README.md generated vendored Normal file
View File

@@ -0,0 +1,64 @@
# Gorilla WebSocket
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket)
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
### Documentation
* [API Reference](http://godoc.org/github.com/gorilla/websocket)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
### Status
The Gorilla WebSocket package provides a complete and tested implementation of
the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The
package API is stable.
### Installation
go get github.com/gorilla/websocket
### Protocol Compliance
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="http://autobahn.ws/testsuite/">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.

396
vendor/github.com/gorilla/websocket/client.go generated vendored Normal file
View File

@@ -0,0 +1,396 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")
var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
// NewClient creates a new client connection using the given net connection.
// The URL u specifies the host and request URI. Use requestHeader to specify
// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
// (Cookie). Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etc.
//
// Deprecated: Use Dialer instead.
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
d := Dialer{
ReadBufferSize: readBufSize,
WriteBufferSize: writeBufSize,
NetDial: func(net, addr string) (net.Conn, error) {
return netConn, nil
},
}
return d.Dial(u.String(), requestHeader)
}
// A Dialer contains options for connecting to WebSocket server.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// If Proxy is nil or returns a nil *URL, no proxy is used.
// Proxy func(*http.Request) (*url.URL, error)
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
// EnableCompression bool
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar
}
// Dial creates a new client connection by calling DialContext with a background context.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
return d.DialContext(context.Background(), urlStr, requestHeader)
}
var errMalformedURL = errors.New("malformed ws or wss URL")
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i]
} else {
switch u.Scheme {
case "wss":
hostPort += ":443"
case "https":
hostPort += ":443"
default:
hostPort += ":80"
}
}
return hostPort, hostNoPort
}
// DefaultDialer is a dialer with all fields set to the default values.
var DefaultDialer = &Dialer{
// Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
}
// nilDialer is dialer to use when receiver is nil.
var nilDialer = *DefaultDialer
// DialContext creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// The context will be used in the request and in the Dialer.
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not
// need to be closed by the application.
func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil {
d = &nilDialer
}
challengeKey, err := generateChallengeKey()
if err != nil {
return nil, nil, err
}
u, err := url.Parse(urlStr)
if err != nil {
return nil, nil, err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return nil, nil, errMalformedURL
}
if u.User != nil {
// User name and password are not allowed in websocket URIs.
return nil, nil, errMalformedURL
}
req := &http.Request{
Method: "GET",
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: u.Host,
}
req = req.WithContext(ctx)
// Set the cookies present in the cookie jar of the dialer
if d.Jar != nil {
for _, cookie := range d.Jar.Cookies(u) {
req.AddCookie(cookie)
}
}
// Set the request headers using the capitalization for names and values in
// RFC examples. Although the capitalization shouldn't matter, there are
// servers that depend on it. The Header.Set method is not used because the
// method canonicalizes the header names.
req.Header["Upgrade"] = []string{"websocket"}
req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"}
if len(d.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
}
for k, vs := range requestHeader {
switch {
case k == "Host":
if len(vs) > 0 {
req.Host = vs[0]
}
case k == "Upgrade" ||
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
case k == "Sec-Websocket-Protocol":
req.Header["Sec-WebSocket-Protocol"] = vs
default:
req.Header[k] = vs
}
}
//if d.EnableCompression {
// req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
//}
if d.HandshakeTimeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
defer cancel()
}
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
} else {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
}
}
// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}
// If needed, wrap the dial function to connect through a proxy.
/*
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
if err != nil {
return nil, nil, err
}
netDial = dialer.Dial
}
}*/
hostPort, hostNoPort := hostPortNoPort(u)
//trace := httptrace.ContextClientTrace(ctx)
//if trace != nil && trace.GetConn != nil {
// trace.GetConn(hostPort)
//}
netConn, err := netDial("tcp", hostPort)
//if trace != nil && trace.GotConn != nil {
// trace.GotConn(httptrace.GotConnInfo{
// Conn: netConn,
// })
//}
if err != nil {
return nil, nil, err
}
defer func() {
if netConn != nil {
netConn.Close()
}
}()
if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
var err error
//if trace != nil {
// err = doHandshakeWithTrace(trace, tlsConn, cfg)
//} else {
err = doHandshake(tlsConn, cfg)
//}
if err != nil {
return nil, nil, err
}
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
//if trace != nil && trace.GotFirstResponseByte != nil {
// if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
// trace.GotFirstResponseByte()
// }
//}
resp, err := http.ReadResponse(conn.br, req)
if err != nil {
return nil, nil, err
}
if d.Jar != nil {
if rc := resp.Cookies(); len(rc) > 0 {
d.Jar.SetCookies(u, rc)
}
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake
}
/*
for _, ext := range parseExtensions(resp.Header) {
if ext[""] != "permessage-deflate" {
continue
}
_, snct := ext["server_no_context_takeover"]
_, cnct := ext["client_no_context_takeover"]
if !snct || !cnct {
return nil, resp, errInvalidCompression
}
conn.newCompressionWriter = compressNoContextTakeover
conn.newDecompressionReader = decompressNoContextTakeover
break
}*/
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
return conn, resp, nil
}
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

16
vendor/github.com/gorilla/websocket/client_clone.go generated vendored Normal file
View File

@@ -0,0 +1,16 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

1143
vendor/github.com/gorilla/websocket/conn.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

15
vendor/github.com/gorilla/websocket/conn_write.go generated vendored Normal file
View File

@@ -0,0 +1,15 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "net"
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}

180
vendor/github.com/gorilla/websocket/doc.go generated vendored Normal file
View File

@@ -0,0 +1,180 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements the WebSocket protocol defined in RFC 6455.
//
// Overview
//
// The Conn type represents a WebSocket connection. A server application calls
// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
//
// var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024,
// WriteBufferSize: 1024,
// }
//
// func handler(w http.ResponseWriter, r *http.Request) {
// conn, err := upgrader.Upgrade(w, r, nil)
// if err != nil {
// log.Println(err)
// return
// }
// ... Use conn to send and receive messages.
// }
//
// Call the connection's WriteMessage and ReadMessage methods to send and
// receive messages as a slice of bytes. This snippet of code shows how to echo
// messages using these methods:
//
// for {
// messageType, p, err := conn.ReadMessage()
// if err != nil {
// log.Println(err)
// return
// }
// if err := conn.WriteMessage(messageType, p); err != nil {
// log.Println(err)
// return
// }
// }
//
// In above snippet of code, p is a []byte and messageType is an int with value
// websocket.BinaryMessage or websocket.TextMessage.
//
// An application can also send and receive messages using the io.WriteCloser
// and io.Reader interfaces. To send a message, call the connection NextWriter
// method to get an io.WriteCloser, write the message to the writer and close
// the writer when done. To receive a message, call the connection NextReader
// method to get an io.Reader and read until io.EOF is returned. This snippet
// shows how to echo messages using the NextWriter and NextReader methods:
//
// for {
// messageType, r, err := conn.NextReader()
// if err != nil {
// return
// }
// w, err := conn.NextWriter(messageType)
// if err != nil {
// return err
// }
// if _, err := io.Copy(w, r); err != nil {
// return err
// }
// if err := w.Close(); err != nil {
// return err
// }
// }
//
// Data Messages
//
// The WebSocket protocol distinguishes between text and binary data messages.
// Text messages are interpreted as UTF-8 encoded text. The interpretation of
// binary messages is left to the application.
//
// This package uses the TextMessage and BinaryMessage integer constants to
// identify the two data message types. The ReadMessage and NextReader methods
// return the type of the received message. The messageType argument to the
// WriteMessage and NextWriter methods specifies the type of a sent message.
//
// It is the application's responsibility to ensure that text messages are
// valid UTF-8 encoded text.
//
// Control Messages
//
// The WebSocket protocol defines three types of control messages: close, ping
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
// methods to send a control message to the peer.
//
// Connections handle received close messages by calling the handler function
// set with the SetCloseHandler method and by returning a *CloseError from the
// NextReader, ReadMessage or the message Read method. The default close
// handler sends a close message to the peer.
//
// Connections handle received ping messages by calling the handler function
// set with the SetPingHandler method. The default ping handler sends a pong
// message to the peer.
//
// Connections handle received pong messages by calling the handler function
// set with the SetPongHandler method. The default pong handler does nothing.
// If an application sends ping messages, then the application should set a
// pong handler to receive the corresponding pong.
//
// The control message handler functions are called from the NextReader,
// ReadMessage and message reader Read methods. The default close and ping
// handlers can block these methods for a short time when the handler writes to
// the connection.
//
// The application must read the connection to process close, ping and pong
// messages sent from the peer. If the application is not otherwise interested
// in messages from the peer, then the application should start a goroutine to
// read and discard messages from the peer. A simple example is:
//
// func readLoop(c *websocket.Conn) {
// for {
// if _, _, err := c.NextReader(); err != nil {
// c.Close()
// break
// }
// }
// }
//
// Concurrency
//
// Connections support one concurrent reader and one concurrent writer.
//
// Applications are responsible for ensuring that no more than one goroutine
// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage,
// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and
// that no more than one goroutine calls the read methods (NextReader,
// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler)
// concurrently.
//
// The Close and WriteControl methods can be called concurrently with all other
// methods.
//
// Origin Considerations
//
// Web browsers allow Javascript applications to open a WebSocket connection to
// any host. It's up to the server to enforce an origin policy using the Origin
// request header sent by the browser.
//
// The Upgrader calls the function specified in the CheckOrigin field to check
// the origin. If the CheckOrigin function returns false, then the Upgrade
// method fails the WebSocket handshake with HTTP status 403.
//
// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
// the handshake if the Origin request header is present and the Origin host is
// not equal to the Host request header.
//
// The deprecated package-level Upgrade function does not perform origin
// checking. The application is responsible for checking the Origin header
// before calling the Upgrade function.
//
// Compression EXPERIMENTAL
//
// Per message compression extensions (RFC 7692) are experimentally supported
// by this package in a limited capacity. Setting the EnableCompression option
// to true in Dialer or Upgrader will attempt to negotiate per message deflate
// support.
//
// var upgrader = websocket.Upgrader{
// EnableCompression: true,
// }
//
// If compression was successfully negotiated with the connection's peer, any
// message received in compressed form will be automatically decompressed.
// All Read methods will return uncompressed bytes.
//
// Per message compression of messages written to a connection can be enabled
// or disabled by calling the corresponding Conn method:
//
// conn.EnableWriteCompression(false)
//
// Currently this package does not support compression with "context takeover".
// This means that messages must be compressed and decompressed in isolation,
// without retaining sliding window or dictionary state across messages. For
// more details refer to RFC 7692.
//
// Use of compression is experimental and may result in decreased performance.
package websocket

54
vendor/github.com/gorilla/websocket/mask.go generated vendored Normal file
View File

@@ -0,0 +1,54 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// +build !appengine
package websocket
import "unsafe"
const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
b[i] ^= key[pos&3]
pos++
}
b = b[n:]
}
// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
}
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}
// Mask one byte at a time for remaining bytes.
b = b[n:]
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

15
vendor/github.com/gorilla/websocket/mask_safe.go generated vendored Normal file
View File

@@ -0,0 +1,15 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// +build appengine
package websocket
func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

363
vendor/github.com/gorilla/websocket/server.go generated vendored Normal file
View File

@@ -0,0 +1,363 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"errors"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
message string
}
func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.
ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client. If there's no match, then no protocol is
// negotiated (the Sec-Websocket-Protocol header is not included in the
// handshake response).
Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, then a safe default is used: return false if the
// Origin request header is present and the origin host is not equal to
// request Host header.
//
// A CheckOrigin function should carefully validate the request origin to
// prevent cross-site request forgery.
CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
// EnableCompression bool
}
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
err := HandshakeError{reason}
if u.Error != nil {
u.Error(w, r, status, err)
} else {
w.Header().Set("Sec-Websocket-Version", "13")
http.Error(w, http.StatusText(status), status)
}
return nil, err
}
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return equalASCIIFold(u.Host, r.Host)
}
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
}
}
}
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
}
return ""
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-WebSocket-Protocol).
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
const badHandshake = "websocket: the client is not using the websocket protocol: "
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
}
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
}
if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
}
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
}
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
}
checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOrigin
}
if !checkOrigin(r) {
return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
}
challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
}
subprotocol := u.selectSubprotocol(r, responseHeader)
// Negotiate PMCE
const compress bool = false
//if u.EnableCompression {
// for _, ext := range parseExtensions(r.Header) {
// if ext[""] != "permessage-deflate" {
// continue
// }
// compress = true
// break
// }
//}
h, ok := w.(http.Hijacker)
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
if brw.Reader.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
var br *bufio.Reader
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
// Reuse hijacked buffered reader as connection reader.
br = brw.Reader
}
buf := bufioWriterBuffer(netConn, brw.Writer)
var writeBuf []byte
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
// Reuse hijacked write buffer as connection buffer.
writeBuf = buf
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
c.subprotocol = subprotocol
//if compress {
// c.newCompressionWriter = compressNoContextTakeover
// c.newDecompressionReader = decompressNoContextTakeover
//}
// Use larger of hijacked buffer and connection write buffer for header.
p := buf
if len(c.writeBuf) > len(p) {
p = c.writeBuf
}
p = p[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
if c.subprotocol != "" {
p = append(p, "Sec-WebSocket-Protocol: "...)
p = append(p, c.subprotocol...)
p = append(p, "\r\n"...)
}
//if compress {
// p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
//}
for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" {
continue
}
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
for i := 0; i < len(v); i++ {
b := v[i]
if b <= 31 {
// prevent response splitting.
b = ' '
}
p = append(p, b)
}
p = append(p, "\r\n"...)
}
}
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
return c, nil
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// Deprecated: Use websocket.Upgrader instead.
//
// Upgrade does not perform origin checking. The application is responsible for
// checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", http.StatusForbidden)
// return
// }
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
// don't return errors to maintain backwards compatibility
}
u.CheckOrigin = func(r *http.Request) bool {
// allow all connections by default
return true
}
return u.Upgrade(w, r, responseHeader)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols(r *http.Request) []string {
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket")
}
// bufioReaderSize size returns the size of a bufio.Reader.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
// This code assumes that peek on a reset reader returns
// bufio.Reader.buf[:0].
// TODO: Use bufio.Reader.Size() after Go 1.10
br.Reset(originalReader)
if p, err := br.Peek(0); err == nil {
return cap(p)
}
return 0
}
// writeHook is an io.Writer that records the last slice passed to it vio
// io.Writer.Write.
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// This code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)
return wh.p[:cap(wh.p)]
}

283
vendor/github.com/gorilla/websocket/util.go generated vendored Normal file
View File

@@ -0,0 +1,283 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"io"
"net/http"
"strings"
"unicode/utf8"
)
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
func generateChallengeKey() (string, error) {
p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(p), nil
}
// Token octets per RFC 2616.
var isTokenOctet = [256]bool{
'!': true,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'*': true,
'+': true,
'-': true,
'.': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
}
// skipSpace returns a slice of the string s with all leading RFC 2616 linear
// whitespace removed.
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if b := s[i]; b != ' ' && b != '\t' {
break
}
}
return s[i:]
}
// nextToken returns the leading RFC 2616 token of s and the string following
// the token.
func nextToken(s string) (token, rest string) {
i := 0
for ; i < len(s); i++ {
if !isTokenOctet[s[i]] {
break
}
}
return s[:i], s[i:]
}
// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
// and the string following the token or quoted string.
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
}
s = s[1:]
for i := 0; i < len(s); i++ {
switch s[i] {
case '"':
return s[:i], s[i+1:]
case '\\':
p := make([]byte, len(s)-1)
j := copy(p, s[:i])
escape := true
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j++
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j++
}
}
return "", ""
}
}
return "", ""
}
// equalASCIIFold returns true if s is equal to t with ASCII case folding as
// defined in RFC 4790.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
// tokenListContainsValue returns true if the 1#token header with the given
// name contains a token equal to value with ASCII case folding.
func tokenListContainsValue(header http.Header, name string, value string) bool {
headers:
for _, s := range header[name] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
s = skipSpace(s)
if s != "" && s[0] != ',' {
continue headers
}
if equalASCIIFold(t, value) {
return true
}
if s == "" {
continue headers
}
s = s[1:]
}
}
return false
}
// parseExtensions parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
s = s[1:]
}
}
return result
}

View File

@@ -27,7 +27,7 @@ type client struct {
token []byte
versionNegotiated bool // has the server accepted our version
versionNegotiated utils.AtomicBool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
@@ -59,6 +59,7 @@ var (
)
// DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC session is closed.
// The hostname for SNI is taken from the given address.
func DialAddr(
addr string,
@@ -69,7 +70,7 @@ func DialAddr(
}
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// The hostname for SNI is taken from the given address.
// See DialAddr for details.
func DialAddrContext(
ctx context.Context,
addr string,
@@ -88,6 +89,8 @@ func DialAddrContext(
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The same PacketConn can be used for multiple calls to Dial and Listen,
// QUIC connection IDs are used for demultiplexing the different connections.
// The host parameter is used for SNI.
func Dial(
pconn net.PacketConn,
@@ -100,7 +103,7 @@ func Dial(
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// The host parameter is used for SNI.
// See Dial for details.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
@@ -164,7 +167,18 @@ func newClient(
}
}
}
srcConnID, err := generateConnectionID(config.ConnectionIDLength)
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
c := &client{
srcConnID: srcConnID,
destConnID: destConnID,
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
tlsConf: tlsConf,
@@ -173,7 +187,7 @@ func newClient(
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, c.generateConnectionIDs()
return c, nil
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
@@ -234,20 +248,6 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
}
}
func (c *client) generateConnectionIDs() error {
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
if err != nil {
return err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return err
}
c.srcConnID = srcConnID
c.destConnID = destConnID
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
@@ -292,65 +292,49 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
}
func (c *client) handlePacket(p *receivedPacket) {
if err := c.handlePacketImpl(p); err != nil {
c.logger.Errorf("error handling packet: %s", err)
}
}
func (c *client) handlePacketImpl(p *receivedPacket) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// handle Version Negotiation Packets
if p.header.IsVersionNegotiation {
err := c.handleVersionNegotiationPacket(p.header)
if err != nil {
c.session.destroy(err)
}
// version negotiation packets have no payload
return err
if p.hdr.IsVersionNegotiation() {
go c.handleVersionNegotiationPacket(p.hdr)
return
}
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
if p.header.Type == protocol.PacketTypeRetry {
c.handleRetryPacket(p.header)
return nil
if p.hdr.Type == protocol.PacketTypeRetry {
go c.handleRetryPacket(p.hdr)
return
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
if !c.versionNegotiated.Get() {
c.versionNegotiated.Set(true)
}
c.session.handlePacket(p)
return nil
}
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
c.mutex.Lock()
defer c.mutex.Unlock()
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
return nil
if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
c.logger.Debugf("Received a delayed Version Negotiation packet.")
return
}
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
// this might be a packet sent by an attacker (or by a terribly broken server implementation)
// ignore it
return nil
// The Version Negotiation packet contains the version that we offered.
// This might be a packet sent by an attacker (or by a terribly broken server implementation).
return
}
}
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
c.session.destroy(qerr.InvalidVersion)
c.logger.Debugf("No compatible version found.")
return
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
@@ -358,18 +342,17 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
if err := c.generateConnectionIDs(); err != nil {
return err
}
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
return nil
}
func (c *client) handleRetryPacket(hdr *wire.Header) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.logger.Debugf("<- Received Retry")
hdr.Log(c.logger)
(&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
return

View File

@@ -75,12 +75,10 @@ type sentPacketHandler struct {
alarm time.Time
logger utils.Logger
version protocol.VersionNumber
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
@@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
@@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
pn := h.packetNumberGenerator.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked())
}
func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {

View File

@@ -11,11 +11,13 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
c = &tls.Config{}
}
// QUIC requires TLS 1.3 or newer
if c.MinVersion < qtls.VersionTLS13 {
c.MinVersion = qtls.VersionTLS13
minVersion := c.MinVersion
if minVersion < qtls.VersionTLS13 {
minVersion = qtls.VersionTLS13
}
if c.MaxVersion < qtls.VersionTLS13 {
c.MaxVersion = qtls.VersionTLS13
maxVersion := c.MaxVersion
if maxVersion < qtls.VersionTLS13 {
maxVersion = qtls.VersionTLS13
}
return &qtls.Config{
Rand: c.Rand,
@@ -38,8 +40,8 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
MinVersion: minVersion,
MaxVersion: maxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,

View File

@@ -1,20 +1,37 @@
package protocol
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
PacketNumberLenInvalid PacketNumberLen = 0
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen3 is a packet number length of 3 bytes
PacketNumberLen3 PacketNumberLen = 3
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
)
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func InferPacketNumber(
packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber,
version VersionNumber,
) PacketNumber {
var epochDelta PacketNumber
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7
epochDelta = PacketNumber(1) << 8
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14
epochDelta = PacketNumber(1) << 16
case PacketNumberLen3:
epochDelta = PacketNumber(1) << 24
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30
epochDelta = PacketNumber(1) << 32
}
epoch := lastPacketNumber & ^(epochDelta - 1)
prevEpochBegin := epoch - epochDelta
@@ -42,11 +59,14 @@ func delta(a, b PacketNumber) PacketNumber {
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked)
if diff < (1 << (14 - 1)) {
if diff < (1 << (16 - 1)) {
return PacketNumberLen2
}
if diff < (1 << (24 - 1)) {
return PacketNumberLen3
}
return PacketNumberLen4
}
@@ -58,5 +78,8 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
return PacketNumberLen2
}
if packetNumber < (1 << (uint8(PacketNumberLen3) * 8)) {
return PacketNumberLen3
}
return PacketNumberLen4
}

View File

@@ -7,32 +7,18 @@ import (
// A PacketNumber in QUIC
type PacketNumber uint64
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
PacketNumberLenInvalid PacketNumberLen = 0
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
)
// The PacketType is the Long Header Type
type PacketType uint8
const (
// PacketTypeInitial is the packet type of an Initial packet
PacketTypeInitial PacketType = 0x7f
PacketTypeInitial PacketType = 1 + iota
// PacketTypeRetry is the packet type of a Retry packet
PacketTypeRetry PacketType = 0x7e
PacketTypeRetry
// PacketTypeHandshake is the packet type of a Handshake packet
PacketTypeHandshake PacketType = 0x7d
PacketTypeHandshake
// PacketType0RTT is the packet type of a 0-RTT packet
PacketType0RTT PacketType = 0x7c
PacketType0RTT
)
func (t PacketType) String() string {
@@ -72,11 +58,5 @@ const DefaultTCPMSS ByteCount = 1460
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200
// MaxClientHellos is the maximum number of times we'll send a client hello
// The value 3 accounts for:
// * one failure due to an incorrect or missing source-address token
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8

View File

@@ -8,11 +8,10 @@ import (
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface {
ReadUintN(b io.ByteReader, length uint8) (uint64, error)
ReadUint64(io.ByteReader) (uint64, error)
ReadUint32(io.ByteReader) (uint32, error)
ReadUint16(io.ByteReader) (uint16, error)
WriteUint64(*bytes.Buffer, uint64)
WriteUintN(b *bytes.Buffer, length uint8, value uint64)
WriteUint32(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16)
}

View File

@@ -25,37 +25,6 @@ func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
return res, nil
}
// ReadUint64 reads a uint64
func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) {
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
var err error
if b8, err = b.ReadByte(); err != nil {
return 0, err
}
if b7, err = b.ReadByte(); err != nil {
return 0, err
}
if b6, err = b.ReadByte(); err != nil {
return 0, err
}
if b5, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
}
// ReadUint32 reads a uint32
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
@@ -88,12 +57,10 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
return uint16(b1) + uint16(b2)<<8, nil
}
// WriteUint64 writes a uint64
func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
b.Write([]byte{
uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) {
for j := length; j > 0; j-- {
b.WriteByte(uint8(i >> (8 * (j - 1))))
}
}
// WriteUint32 writes a uint32

View File

@@ -1,13 +1,5 @@
package utils
import (
"fmt"
"log"
"os"
"strings"
"time"
)
// LogLevel of quic-go
type LogLevel uint8
@@ -22,8 +14,6 @@ const (
LogLevelDebug
)
const logEnv = "QUIC_GO_LOG_LEVEL"
// A Logger logs.
type Logger interface {
SetLogLevel(LogLevel)
@@ -39,93 +29,36 @@ type Logger interface {
// DefaultLogger is used by quic-go for logging.
var DefaultLogger Logger
type defaultLogger struct {
prefix string
logLevel LogLevel
timeFormat string
}
type defaultLogger struct{}
var _ Logger = &defaultLogger{}
// SetLogLevel sets the log level
func (l *defaultLogger) SetLogLevel(level LogLevel) {
l.logLevel = level
}
// SetLogTimeFormat sets the format of the timestamp
// an empty string disables the logging of timestamps
func (l *defaultLogger) SetLogTimeFormat(format string) {
log.SetFlags(0) // disable timestamp logging done by the log package
l.timeFormat = format
}
func (l *defaultLogger) SetLogTimeFormat(format string) {}
// Debugf logs something
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
if l.logLevel == LogLevelDebug {
l.logMessage(format, args...)
}
}
func (l *defaultLogger) Debugf(format string, args ...interface{}) {}
// Infof logs something
func (l *defaultLogger) Infof(format string, args ...interface{}) {
if l.logLevel >= LogLevelInfo {
l.logMessage(format, args...)
}
}
func (l *defaultLogger) Infof(format string, args ...interface{}) {}
// Errorf logs something
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
if l.logLevel >= LogLevelError {
l.logMessage(format, args...)
}
}
func (l *defaultLogger) logMessage(format string, args ...interface{}) {
var pre string
if len(l.timeFormat) > 0 {
pre = time.Now().Format(l.timeFormat) + " "
}
if len(l.prefix) > 0 {
pre += l.prefix + " "
}
log.Printf(pre+format, args...)
}
func (l *defaultLogger) Errorf(format string, args ...interface{}) {}
func (l *defaultLogger) WithPrefix(prefix string) Logger {
if len(l.prefix) > 0 {
prefix = l.prefix + " " + prefix
}
return &defaultLogger{
logLevel: l.logLevel,
timeFormat: l.timeFormat,
prefix: prefix,
}
return l
}
// Debug returns true if the log level is LogLevelDebug
func (l *defaultLogger) Debug() bool {
return l.logLevel == LogLevelDebug
return false
}
func init() {
DefaultLogger = &defaultLogger{}
DefaultLogger.SetLogLevel(readLoggingEnv())
}
func readLoggingEnv() LogLevel {
switch strings.ToLower(os.Getenv(logEnv)) {
case "":
return LogLevelNothing
case "debug":
return LogLevelDebug
case "info":
return LogLevelInfo
case "error":
return LogLevelError
default:
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging")
return LogLevelNothing
}
}

View File

@@ -0,0 +1,205 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// ExtendedHeader is the header of a QUIC packet.
type ExtendedHeader struct {
Header
typeByte byte
Raw []byte
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
KeyPhase int
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
// read the (now unencrypted) first byte
var err error
h.typeByte, err = b.ReadByte()
if err != nil {
return nil, err
}
if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil {
return nil, err
}
if h.IsLongHeader {
return h.parseLongHeader(b, v)
}
return h.parseShortHeader(b, v)
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
if h.typeByte&0xc != 0 {
return nil, errors.New("5th and 6th bit must be 0")
}
if err := h.readPacketNumber(b); err != nil {
return nil, err
}
return h, nil
}
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
if h.typeByte&0x18 != 0 {
return nil, errors.New("4th and 5th bit must be 0")
}
h.KeyPhase = int(h.typeByte&0x4) >> 2
if err := h.readPacketNumber(b); err != nil {
return nil, err
}
return h, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(pn)
return nil
}
// Write writes the Header.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
if h.IsLongHeader {
return h.writeLongHeader(b, ver)
}
return h.writeShortHeader(b, ver)
}
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
var packetType uint8
switch h.Type {
case protocol.PacketTypeInitial:
packetType = 0x0
case protocol.PacketType0RTT:
packetType = 0x1
case protocol.PacketTypeHandshake:
packetType = 0x2
case protocol.PacketTypeRetry:
packetType = 0x3
}
firstByte := 0xc0 | packetType<<4
if h.Type == protocol.PacketTypeRetry {
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
if err != nil {
return err
}
firstByte |= odcil
} else { // Retry packets don't have a packet number
firstByte |= uint8(h.PacketNumberLen - 1)
}
b.WriteByte(firstByte)
utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
if err != nil {
return err
}
b.WriteByte(connIDLen)
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
switch h.Type {
case protocol.PacketTypeRetry:
b.Write(h.OrigDestConnectionID.Bytes())
b.Write(h.Token)
return nil
case protocol.PacketTypeInitial:
utils.WriteVarInt(b, uint64(len(h.Token)))
b.Write(h.Token)
}
utils.WriteVarInt(b, uint64(h.Length))
return h.writePacketNumber(b)
}
// TODO: add support for the key phase
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
typeByte := 0x40 | uint8(h.PacketNumberLen-1)
typeByte |= byte(h.KeyPhase << 2)
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return h.writePacketNumber(b)
}
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 {
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber))
return nil
}
// GetLength determines the length of the Header.
func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
if h.Type == protocol.PacketTypeInitial {
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
}
return length
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
length += protocol.ByteCount(h.PacketNumberLen)
return length
}
// Log logs the Header
func (h *ExtendedHeader) Log(logger utils.Logger) {
if h.IsLongHeader {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
return
}
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
}
}
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
dcil, err := encodeSingleConnIDLen(dest)
if err != nil {
return 0, err
}
scil, err := encodeSingleConnIDLen(src)
if err != nil {
return 0, err
}
return scil | dcil<<4, nil
}
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
len := id.Len()
if len == 0 {
return 0, nil
}
if len < 4 || len > 18 {
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
}
return byte(len - 3), nil
}

View File

@@ -2,150 +2,183 @@ package wire
import (
"bytes"
"crypto/rand"
"fmt"
"errors"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Header is the header of a QUIC packet.
// The Header is the version independent part of the header
type Header struct {
Raw []byte
Version protocol.VersionNumber
SrcConnectionID protocol.ConnectionID
DestConnectionID protocol.ConnectionID
Version protocol.VersionNumber
DestConnectionID protocol.ConnectionID
SrcConnectionID protocol.ConnectionID
OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
IsVersionNegotiation bool
SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
Type protocol.PacketType
IsLongHeader bool
KeyPhase int
PayloadLen protocol.ByteCount
Token []byte
Type protocol.PacketType
Length protocol.ByteCount
Token []byte
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
typeByte byte
len int // how many bytes were read while parsing this header
}
// Write writes the Header.
func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
if h.IsLongHeader {
return h.writeLongHeader(b, ver)
// ParseHeader parses the header.
// For short header packets: up to the packet number.
// For long header packets:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
startLen := b.Len()
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
if err != nil {
return nil, err
}
return h.writeShortHeader(b, ver)
h.len = startLen - b.Len()
return h, nil
}
// TODO: add support for the key phase
func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
b.WriteByte(byte(0x80 | h.Type))
utils.BigEndian.WriteUint32(b, uint32(h.Version))
connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}
h := &Header{
typeByte: typeByte,
IsLongHeader: typeByte&0x80 > 0,
}
if !h.IsLongHeader {
if h.typeByte&0x40 == 0 {
return nil, errors.New("not a QUIC packet")
}
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
return nil, err
}
return h, nil
}
if err := h.parseLongHeader(b); err != nil {
return nil, err
}
return h, nil
}
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
var err error
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
return err
}
func (h *Header) parseLongHeader(b *bytes.Reader) error {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
}
b.WriteByte(connIDLen)
b.Write(h.DestConnectionID.Bytes())
b.Write(h.SrcConnectionID.Bytes())
if h.Type == protocol.PacketTypeInitial {
utils.WriteVarInt(b, uint64(len(h.Token)))
b.Write(h.Token)
h.Version = protocol.VersionNumber(v)
if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet")
}
if h.Type == protocol.PacketTypeRetry {
odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
if err != nil {
return err
}
// randomize the first 4 bits
odcilByte := make([]byte, 1)
_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
b.Write(odcilByte)
b.Write(h.OrigDestConnectionID.Bytes())
b.Write(h.Token)
connIDLenByte, err := b.ReadByte()
if err != nil {
return err
}
dcil, scil := decodeConnIDLen(connIDLenByte)
h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
if err != nil {
return err
}
h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
if err != nil {
return err
}
if h.Version == 0 {
return h.parseVersionNegotiationPacket(b)
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return nil
}
utils.WriteVarInt(b, uint64(h.PayloadLen))
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
switch (h.typeByte & 0x30) >> 4 {
case 0x0:
h.Type = protocol.PacketTypeInitial
case 0x1:
h.Type = protocol.PacketType0RTT
case 0x2:
h.Type = protocol.PacketTypeHandshake
case 0x3:
h.Type = protocol.PacketTypeRetry
}
func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
typeByte := byte(0x30)
typeByte |= byte(h.KeyPhase << 6)
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
// GetLength determines the length of the Header.
func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.PayloadLen))
if h.Type == protocol.PacketTypeInitial {
length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
if h.Type == protocol.PacketTypeRetry {
odcil := decodeSingleConnIDLen(h.typeByte & 0xf)
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
if err != nil {
return err
}
return length
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
length += protocol.ByteCount(h.PacketNumberLen)
return length
}
// Log logs the Header
func (h *Header) Log(logger utils.Logger) {
if h.IsLongHeader {
if h.Version == 0 {
logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
} else {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
return
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
h.Token = make([]byte, b.Len())
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
return nil
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := utils.ReadVarInt(b)
if err != nil {
return err
}
if tokenLen > uint64(b.Len()) {
return io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
}
pl, err := utils.ReadVarInt(b)
if err != nil {
return err
}
h.Length = protocol.ByteCount(pl)
return nil
}
func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
dcil, err := encodeSingleConnIDLen(dest)
if err != nil {
return 0, err
func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error {
if b.Len() == 0 {
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
}
scil, err := encodeSingleConnIDLen(src)
if err != nil {
return 0, err
h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
for i := 0; b.Len() > 0; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return qerr.InvalidVersionNegotiationPacket
}
h.SupportedVersions[i] = protocol.VersionNumber(v)
}
return scil | dcil<<4, nil
return nil
}
func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
len := id.Len()
if len == 0 {
return 0, nil
}
if len < 4 || len > 18 {
return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
}
return byte(len - 3), nil
// IsVersionNegotiation says if this a version negotiation packet
func (h *Header) IsVersionNegotiation() bool {
return h.IsLongHeader && h.Version == 0
}
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
return h.toExtendedHeader().parse(b, ver)
}
func (h *Header) toExtendedHeader() *ExtendedHeader {
return &ExtendedHeader{Header: *h}
}
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {

View File

@@ -1,167 +0,0 @@
package wire
import (
"bytes"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// The InvariantHeader is the version independent part of the header
type InvariantHeader struct {
IsLongHeader bool
Version protocol.VersionNumber
SrcConnectionID protocol.ConnectionID
DestConnectionID protocol.ConnectionID
typeByte byte
}
// ParseInvariantHeader parses the version independent part of the header
func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*InvariantHeader, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}
h := &InvariantHeader{typeByte: typeByte}
h.IsLongHeader = typeByte&0x80 > 0
// If this is not a Long Header, it could either be a Public Header or a Short Header.
if !h.IsLongHeader {
var err error
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
if err != nil {
return nil, err
}
return h, nil
}
// Long Header
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, err
}
h.Version = protocol.VersionNumber(v)
connIDLenByte, err := b.ReadByte()
if err != nil {
return nil, err
}
dcil, scil := decodeConnIDLen(connIDLenByte)
h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
if err != nil {
return nil, err
}
h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
if err != nil {
return nil, err
}
return h, nil
}
// Parse parses the version dependent part of the header
func (iv *InvariantHeader) Parse(b *bytes.Reader, sentBy protocol.Perspective, ver protocol.VersionNumber) (*Header, error) {
if iv.IsLongHeader {
if iv.Version == 0 { // Version Negotiation Packet
return iv.parseVersionNegotiationPacket(b)
}
return iv.parseLongHeader(b, sentBy, ver)
}
return iv.parseShortHeader(b, ver)
}
func (iv *InvariantHeader) toHeader() *Header {
return &Header{
IsLongHeader: iv.IsLongHeader,
DestConnectionID: iv.DestConnectionID,
SrcConnectionID: iv.SrcConnectionID,
Version: iv.Version,
}
}
func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Header, error) {
h := iv.toHeader()
if b.Len() == 0 {
return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
}
h.IsVersionNegotiation = true
h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
for i := 0; b.Len() > 0; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return nil, qerr.InvalidVersionNegotiationPacket
}
h.SupportedVersions[i] = protocol.VersionNumber(v)
}
return h, nil
}
func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader, sentBy protocol.Perspective, v protocol.VersionNumber) (*Header, error) {
h := iv.toHeader()
h.Type = protocol.PacketType(iv.typeByte & 0x7f)
if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
}
if h.Type == protocol.PacketTypeRetry {
odcilByte, err := b.ReadByte()
if err != nil {
return nil, err
}
odcil := decodeSingleConnIDLen(odcilByte & 0xf)
h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
if err != nil {
return nil, err
}
h.Token = make([]byte, b.Len())
if _, err := io.ReadFull(b, h.Token); err != nil {
return nil, err
}
return h, nil
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
if tokenLen > uint64(b.Len()) {
return nil, io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return nil, err
}
}
pl, err := utils.ReadVarInt(b)
if err != nil {
return nil, err
}
h.PayloadLen = protocol.ByteCount(pl)
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = pn
h.PacketNumberLen = pnLen
return h, nil
}
func (iv *InvariantHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*Header, error) {
h := iv.toHeader()
h.KeyPhase = int(iv.typeByte&0x40) >> 6
pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
if err != nil {
return nil, err
}
h.PacketNumber = pn
h.PacketNumberLen = pnLen
return h, nil
}

View File

@@ -162,72 +162,46 @@ func (h *packetHandlerMap) listen() {
}
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
rcvTime := time.Now()
r := bytes.NewReader(data)
iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
hdr, err := wire.ParseHeader(r, h.connIDLen)
// drop the packet if we can't parse the header
if err != nil {
return fmt.Errorf("error parsing invariant header: %s", err)
}
h.mutex.RLock()
handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)]
server := h.server
var sentBy protocol.Perspective
var version protocol.VersionNumber
var handlePacket func(*receivedPacket)
if handlerFound { // existing session
handler := handlerEntry.handler
sentBy = handler.GetPerspective().Opposite()
version = handler.GetVersion()
handlePacket = handler.handlePacket
} else { // no session found
// this might be a stateless reset
if !iHdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
h.mutex.RUnlock()
sess.destroy(errors.New("received a stateless reset"))
return nil
}
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
if server == nil { // no server set
h.mutex.RUnlock()
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
handlePacket = server.handlePacket
sentBy = protocol.PerspectiveClient
version = iHdr.Version
}
h.mutex.RUnlock()
hdr, err := iHdr.Parse(r, sentBy, version)
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
packetData := data[len(data)-r.Len():]
if hdr.IsLongHeader {
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
}
packetData = packetData[:int(hdr.PayloadLen)]
// TODO(#1312): implement parsing of compound packets
p := &receivedPacket{
remoteAddr: addr,
hdr: hdr,
data: data,
rcvTime: time.Now(),
}
handlePacket(&receivedPacket{
remoteAddr: addr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
h.mutex.RLock()
defer h.mutex.RUnlock()
handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
return nil
}
// No session found.
// This might be a stateless reset.
if !hdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
return nil
}
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
if h.server == nil { // no server set
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
h.server.handlePacket(p)
return nil
}

View File

@@ -25,7 +25,7 @@ type packer interface {
}
type packedPacket struct {
header *wire.Header
header *wire.ExtendedHeader
raw []byte
frames []wire.Frame
encryptionLevel protocol.EncryptionLevel
@@ -397,21 +397,20 @@ func (p *packetPacker) composeNextPacket(
return frames, nil
}
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber()
header := &wire.Header{
PacketNumber: pn,
PacketNumberLen: pnLen,
Version: p.version,
DestConnectionID: p.destConnID,
}
header := &wire.ExtendedHeader{}
header.PacketNumber = pn
header.PacketNumberLen = pnLen
header.Version = p.version
header.DestConnectionID = p.destConnID
if encLevel != protocol.Encryption1RTT {
header.IsLongHeader = true
header.SrcConnectionID = p.srcConnID
// Set the payload len to maximum size.
// Set the length to the maximum packet size.
// Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
header.PayloadLen = p.maxPacketSize
header.Length = p.maxPacketSize
switch encLevel {
case protocol.EncryptionInitial:
header.Type = protocol.PacketTypeInitial
@@ -424,8 +423,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
}
func (p *packetPacker) writeAndSealPacket(
header *wire.Header,
frames []wire.Frame,
header *wire.ExtendedHeader, frames []wire.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := *getPacketBuffer()
@@ -433,24 +431,24 @@ func (p *packetPacker) writeAndSealPacket(
addPadding := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial && !p.hasSentPacket
// the payload length is only needed for Long Headers
// the length is only needed for Long Headers
if header.IsLongHeader {
if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
header.Token = p.token
}
if addPadding {
headerLen := header.GetLength(p.version)
header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen
header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen
} else {
payloadLen := protocol.ByteCount(sealer.Overhead())
length := protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen)
for _, frame := range frames {
payloadLen += frame.Length(p.version)
length += frame.Length(p.version)
}
header.PayloadLen = payloadLen
header.Length = length
}
}
if err := header.Write(buffer, p.perspective, p.version); err != nil {
if err := header.Write(buffer, p.version); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()

View File

@@ -35,7 +35,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
}
}
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) {
func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
buf := *getPacketBuffer()
buf = buf[:0]
defer putPacketBuffer(&buf)

View File

@@ -97,7 +97,7 @@ func (s *sendStream) Write(p []byte) (int, error) {
s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p)
s.sender.onHasStreamData(s.streamID)
go s.sender.onHasStreamData(s.streamID)
var bytesWritten int
var err error
@@ -222,7 +222,7 @@ func (s *sendStream) Close() error {
return fmt.Errorf("Close called for canceled stream %d", s.streamID)
}
s.finishedWriting = true
s.sender.onHasStreamData(s.streamID) // need to send the FIN
go s.sender.onHasStreamData(s.streamID) // need to send the FIN
s.ctxCancel()
return nil
}
@@ -268,10 +268,14 @@ func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
func (s *sendStream) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) {
s.flowController.UpdateSendWindow(frame.ByteOffset)
s.mutex.Lock()
hasData := false
if s.dataForWriting != nil {
s.sender.onHasStreamData(s.streamID)
hasData = true
}
s.mutex.Unlock()
if hasData {
s.sender.onHasStreamData(s.streamID)
}
}
// must be called after locking the mutex

View File

@@ -21,7 +21,6 @@ type packetHandler interface {
handlePacket(*receivedPacket)
io.Closer
destroy(error)
GetVersion() protocol.VersionNumber
GetPerspective() protocol.Perspective
}
@@ -99,7 +98,8 @@ var _ Listener = &server{}
var _ unknownPacketHandler = &server{}
// ListenAddr creates a QUIC server listening on a given address.
// The tls.Config must not be nil, the quic.Config may be nil.
// The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
@@ -118,7 +118,11 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
}
// Listen listens for QUIC connections on a given net.PacketConn.
// The tls.Config must not be nil, the quic.Config may be nil.
// A single PacketConn only be used for a single call to Listen.
// The PacketConn can be used for simultaneous calls to Dial.
// QUIC connection IDs are used for demultiplexing the different connections.
// The tls.Config must not be nil and must contain a certificate configuration.
// The quic.Config may be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
return listen(conn, tlsConf, config)
}
@@ -300,23 +304,17 @@ func (s *server) Addr() net.Addr {
}
func (s *server) handlePacket(p *receivedPacket) {
if err := s.handlePacketImpl(p); err != nil {
s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
}
}
func (s *server) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
hdr := p.hdr
// send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
return s.sendVersionNegotiationPacket(p)
go s.sendVersionNegotiationPacket(p)
return
}
if hdr.Type == protocol.PacketTypeInitial {
go s.handleInitial(p)
}
// TODO(#943): send Stateless Reset
return nil
}
func (s *server) handleInitial(p *receivedPacket) {
@@ -335,11 +333,11 @@ func (s *server) handleInitial(p *receivedPacket) {
}
func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
hdr := p.header
hdr := p.hdr
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
return nil, nil, errors.New("dropping Initial packet with too short connection ID")
}
if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize {
if len(p.data) < protocol.MinInitialPacketSize {
return nil, nil, errors.New("dropping too small Initial packet")
}
@@ -358,7 +356,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con
if !s.config.AcceptCookie(p.remoteAddr, cookie) {
// Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the session.
p.header.Log(s.logger)
(&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger)
return nil, nil, s.sendRetry(p.remoteAddr, hdr)
}
@@ -431,19 +429,18 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
if err != nil {
return err
}
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
Version: hdr.Version,
SrcConnectionID: connID,
DestConnectionID: hdr.SrcConnectionID,
OrigDestConnectionID: hdr.DestConnectionID,
Token: token,
}
replyHdr := &wire.ExtendedHeader{}
replyHdr.IsLongHeader = true
replyHdr.Type = protocol.PacketTypeRetry
replyHdr.Version = hdr.Version
replyHdr.SrcConnectionID = connID
replyHdr.DestConnectionID = hdr.SrcConnectionID
replyHdr.OrigDestConnectionID = hdr.DestConnectionID
replyHdr.Token = token
s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
replyHdr.Log(s.logger)
buf := &bytes.Buffer{}
if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil {
if err := replyHdr.Write(buf, hdr.Version); err != nil {
return err
}
if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
@@ -452,14 +449,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
return nil
}
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
hdr := p.header
s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
hdr := p.hdr
s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
if err != nil {
return err
s.logger.Debugf("Error composing Version Negotiation: %s", err)
return
}
if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
}
_, err = s.conn.WriteTo(data, p.remoteAddr)
return err
}

View File

@@ -32,7 +32,7 @@ func (s *serverSession) handlePacket(p *receivedPacket) {
}
func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
hdr := p.hdr
// Probably an old packet that was sent by the client before the version was negotiated.
// It is safe to drop it.

View File

@@ -1,6 +1,7 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
@@ -21,7 +22,7 @@ import (
)
type unpacker interface {
Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error)
Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
}
type streamGetter interface {
@@ -52,7 +53,7 @@ type cryptoStreamHandler interface {
type receivedPacket struct {
remoteAddr net.Addr
header *wire.Header
hdr *wire.Header
data []byte
rcvTime time.Time
}
@@ -113,7 +114,6 @@ type session struct {
receivedFirstPacket bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
receivedFirstForwardSecurePacket bool
lastRcvdPacketNumber protocol.PacketNumber
// Used to calculate the next packet number from the truncated wire
// representation, and sent back in public reset packets
largestRcvdPacketNumber protocol.PacketNumber
@@ -289,7 +289,7 @@ var newClientSession = func(
func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version)
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.InitialMaxData,
@@ -374,7 +374,7 @@ runLoop:
}
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
putPacketBuffer(&p.header.Raw)
// TODO: putPacketBuffer(&p.extHdr.Raw)
case <-s.handshakeCompleteChan:
s.handleHandshakeComplete()
}
@@ -479,24 +479,41 @@ func (s *session) handleHandshakeComplete() {
}
func (s *session) handlePacketImpl(p *receivedPacket) error {
hdr := p.header
// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID)
if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID)
return nil
}
p.rcvTime = time.Now()
data := p.data
r := bytes.NewReader(data)
hdr, err := p.hdr.ParseExtended(r, s.version)
if err != nil {
return fmt.Errorf("error parsing extended header: %s", err)
}
hdr.Raw = data[:len(data)-r.Len()]
data = data[len(data)-r.Len():]
if hdr.IsLongHeader {
if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
}
if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
}
data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
// TODO(#1312): implement parsing of compound packets
}
// Calculate packet number
hdr.PacketNumber = protocol.InferPacketNumber(
hdr.PacketNumberLen,
s.largestRcvdPacketNumber,
hdr.PacketNumber,
s.version,
)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data)
packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
if s.logger.Debug() {
if err != nil {
s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID)
@@ -530,7 +547,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
}
}
s.lastRcvdPacketNumber = hdr.PacketNumber
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
@@ -543,10 +559,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
}
}
return s.handleFrames(packet.frames, packet.encryptionLevel)
return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel)
}
func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error {
func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
for _, ff := range fs {
var err error
wire.LogFrame(s.logger, ff, false)
@@ -556,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
case *wire.StreamFrame:
err = s.handleStreamFrame(frame, encLevel)
case *wire.AckFrame:
err = s.handleAckFrame(frame, encLevel)
err = s.handleAckFrame(frame, pn, encLevel)
case *wire.ConnectionCloseFrame:
s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
case *wire.ResetStreamFrame:
@@ -702,8 +718,8 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) {
s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data})
}
func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil {
return err
}
s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
@@ -1065,14 +1081,14 @@ func (s *session) scheduleSending() {
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
if s.handshakeComplete {
s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data))
s.logger.Debugf("Received undecryptable packet from %s after the handshake (%d bytes)", p.remoteAddr.String(), len(p.data))
return
}
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber)
s.logger.Infof("Dropping undecrytable packet (%d bytes). Undecryptable packet queue full.", len(p.data))
return
}
s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber)
s.logger.Infof("Queueing packet (%d bytes) for later decryption", len(p.data))
s.undecryptablePackets = append(s.undecryptablePackets, p)
}