Compare commits

..

49 Commits

Author SHA1 Message Date
Darien Raymond
ce91e92435 fix #1496 2019-01-04 00:24:28 +01:00
Darien Raymond
163fe2523e fix buffer pool in quic 2019-01-02 13:13:50 +01:00
Darien Raymond
ec89d42feb sync quic package 2019-01-02 13:01:06 +01:00
Darien Raymond
d20f87da4b comments 2019-01-01 20:16:04 +01:00
Darien Raymond
309fb9c227 update licence 2019-01-01 11:33:21 +01:00
Darien Raymond
3de8389361 rename CloseError() to Interrupt() 2018-12-31 21:25:10 +01:00
Darien Raymond
d35c407419 fix #1493 2018-12-31 10:43:08 +01:00
Darien Raymond
1c830472b9 dns protocol package 2018-12-29 09:03:32 +01:00
Darien Raymond
daa8c9c5da rename NameServerInterface to Client 2018-12-28 20:28:31 +01:00
Darien Raymond
fc1e660c27 change net.IP to net.Address 2018-12-28 20:15:22 +01:00
Darien Raymond
7f1bd9f522 comment 2018-12-27 21:13:02 +01:00
Darien Raymond
41b1ac192e use compact in tls writer 2018-12-27 20:38:24 +01:00
Darien Raymond
c72d853454 return correct error 2018-12-27 19:59:49 +01:00
Darien Raymond
d670629651 more test case for tls sniffing 2018-12-27 19:41:23 +01:00
Darien Raymond
88e757e33f merge duplicated code 2018-12-27 17:00:34 +01:00
Darien Raymond
5ba41e50ec update mocks 2018-12-27 16:37:13 +01:00
Darien Raymond
fc92b6295a compact buffers 2018-12-27 16:36:48 +01:00
Darien Raymond
012a2d6f57 fix #1477 2018-12-17 20:31:54 +01:00
Darien Raymond
ac6c262f2e Update version 2018-12-12 18:37:51 +01:00
Darien Raymond
fd7bbdf5e9 update geosite on release 2018-12-11 16:45:20 +01:00
Darien Raymond
a138ab7cb3 update geosite.dat 2018-12-11 16:43:56 +01:00
Darien Raymond
ec95dca3e5 force packet reader in freedom on UDP 2018-12-11 10:17:50 +01:00
Darien Raymond
4055c467a1 tweak quic server parameters 2018-12-11 09:50:36 +01:00
Darien Raymond
0ca762e0e2 fix a deadlock in cacheReader. fix #1471 2018-12-11 09:17:10 +01:00
Darien Raymond
372da062d4 fix build break 2018-12-10 23:34:54 +01:00
Victoria Raymond
9d27d42139 Merge pull request #1470 from comwrg/fix-sniff-http-ipv6
fix sniff http ipv6 address
2018-12-10 23:10:02 +01:00
Victoria Raymond
ce412aec65 Merge branch 'master' into fix-sniff-http-ipv6 2018-12-10 23:09:55 +01:00
Darien Raymond
7e37d141e2 move parseHost to http protocol 2018-12-10 23:08:16 +01:00
comwrg
e52b387483 fix sniff http ipv6 address 2018-12-10 20:37:17 +08:00
Darien Raymond
867135d85a Merge branch 'master' of https://github.com/v2ray/v2ray-core 2018-12-07 18:02:23 +01:00
Victoria Raymond
86a407713e Merge pull request #1435 from sunshineplan/patch-1
Update install-release.sh
2018-12-07 16:09:40 +01:00
Darien Raymond
769eeb0efd remove plugin support as it is not practical 2018-12-07 09:50:11 +01:00
Darien Raymond
146b4eef0e delay between file uploading 2018-12-06 22:53:23 +01:00
Darien Raymond
93e375fa9a Update version 2018-12-06 22:05:47 +01:00
Darien Raymond
6b355ef461 fix a typo in local dns 2018-12-06 21:34:05 +01:00
Darien Raymond
f49f85b5bd fix #1462 2018-12-06 21:09:41 +01:00
Darien Raymond
21b3f66b8b fix IP parsing in local dns client 2018-12-06 20:11:45 +01:00
Darien Raymond
30b5bffad4 support custom log handler 2018-12-06 17:37:05 +01:00
Darien Raymond
b9450d8475 Revert "use default logger for android and ios"
This reverts commit 9743380e2d.
2018-12-06 17:03:15 +01:00
Darien Raymond
50e77cbb19 fix broken test 2018-12-06 14:44:24 +01:00
Darien Raymond
9743380e2d use default logger for android and ios 2018-12-06 14:40:45 +01:00
Darien Raymond
427679e66d simplify task execution 2018-12-06 11:35:02 +01:00
Darien Raymond
cf1705267e switch to errgroup 2018-12-06 10:22:14 +01:00
Darien Raymond
c89183e6b3 update port picking 2018-12-05 16:27:32 +01:00
Darien Raymond
4104a86b6c use default dns resolver to prevent errors in android 2018-12-05 15:48:40 +01:00
Darien Raymond
82d562d1f0 use session.Outbound.ResolvedIPs 2018-12-04 20:36:51 +01:00
Darien Raymond
98d89aebc2 fix release script 2018-12-04 19:40:12 +01:00
Darien Raymond
72f5e9de16 fix vendor directory 2018-12-04 18:59:14 +01:00
sunshineplan
34373cc1dc Update install-release.sh
由于稳定版本在版本末尾加.0了,所以检查更新的逻辑需要修改了
2018-11-26 15:23:38 +08:00
137 changed files with 13079 additions and 1218 deletions

View File

@@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2015-2018 V2Ray
Copyright (c) 2015-2019 V2Ray
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -8,7 +8,6 @@ import (
"v2ray.com/core/common/net"
"v2ray.com/core/common/signal/done"
"v2ray.com/core/transport"
"v2ray.com/core/transport/pipe"
)
// OutboundListener is a net.Listener for listening gRPC connections.
@@ -73,8 +72,8 @@ func (co *Outbound) Dispatch(ctx context.Context, link *transport.Link) {
co.access.RLock()
if co.closed {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
co.access.RUnlock()
return
}

View File

@@ -46,13 +46,22 @@ func (r *cachedReader) Cache(b *buf.Buffer) {
r.Unlock()
}
func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
func (r *cachedReader) readInternal() buf.MultiBuffer {
r.Lock()
defer r.Unlock()
if r.cache != nil && !r.cache.IsEmpty() {
mb := r.cache
r.cache = nil
return mb
}
return nil
}
func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
mb := r.readInternal()
if mb != nil {
return mb, nil
}
@@ -60,25 +69,21 @@ func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
}
func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
r.Lock()
defer r.Unlock()
if r.cache != nil && !r.cache.IsEmpty() {
mb := r.cache
r.cache = nil
mb := r.readInternal()
if mb != nil {
return mb, nil
}
return r.reader.ReadMultiBufferTimeout(timeout)
}
func (r *cachedReader) CloseError() {
func (r *cachedReader) Interrupt() {
r.Lock()
if r.cache != nil {
r.cache = buf.ReleaseMulti(r.cache)
}
r.Unlock()
r.reader.CloseError()
r.reader.Interrupt()
}
// DefaultDispatcher is a default implementation of Dispatcher.
@@ -258,5 +263,13 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
}
}
if dispatcher == nil {
newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
}
dispatcher.Dispatch(ctx, link)
}

View File

@@ -4,7 +4,6 @@ import (
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/features/stats"
"v2ray.com/core/transport/pipe"
)
type SizeStatWriter struct {
@@ -21,6 +20,6 @@ func (w *SizeStatWriter) Close() error {
return common.Close(w.Writer)
}
func (w *SizeStatWriter) CloseError() {
pipe.CloseError(w.Writer)
func (w *SizeStatWriter) Interrupt() {
common.Interrupt(w.Writer)
}

View File

@@ -9,7 +9,7 @@ import (
// StaticHosts represents static domain-ip mapping in DNS server.
type StaticHosts struct {
ips [][]net.IP
ips [][]net.Address
matchers *strmatcher.MatcherGroup
}
@@ -36,7 +36,7 @@ func toStrMatcher(t DomainMatchingType, domain string) (strmatcher.Matcher, erro
func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDomain) (*StaticHosts, error) {
g := new(strmatcher.MatcherGroup)
sh := &StaticHosts{
ips: make([][]net.IP, len(hosts)+len(legacy)+16),
ips: make([][]net.Address, len(hosts)+len(legacy)+16),
matchers: g,
}
@@ -50,10 +50,10 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
address := ip.AsAddress()
if address.Family().IsDomain() {
return nil, newError("ignoring domain address in static hosts: ", address.Domain()).AtWarning()
return nil, newError("invalid domain address in static hosts: ", address.Domain()).AtWarning()
}
sh.ips[id] = []net.IP{address.IP()}
sh.ips[id] = []net.Address{address}
}
}
@@ -63,9 +63,13 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
return nil, newError("failed to create domain matcher").Base(err)
}
id := g.Add(matcher)
ips := make([]net.IP, len(mapping.Ip))
for idx, ip := range mapping.Ip {
ips[idx] = net.IP(ip)
ips := make([]net.Address, 0, len(mapping.Ip))
for _, ip := range mapping.Ip {
addr := net.IPAddress(ip)
if addr == nil {
return nil, newError("invalid IP address in static hosts: ", ip).AtWarning()
}
ips = append(ips, addr)
}
sh.ips[id] = ips
}
@@ -73,12 +77,11 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma
return sh, nil
}
func filterIP(ips []net.IP, option IPOption) []net.IP {
func filterIP(ips []net.Address, option IPOption) []net.IP {
filtered := make([]net.IP, 0, len(ips))
for _, ip := range ips {
parsed := net.IPAddress(ip)
if (parsed.Family().IsIPv4() && option.IPv4Enable) || (parsed.Family().IsIPv6() && option.IPv6Enable) {
filtered = append(filtered, parsed.IP())
if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
filtered = append(filtered, ip.IP())
}
}
if len(filtered) == 0 {

View File

@@ -7,12 +7,13 @@ import (
"v2ray.com/core/features/dns/localdns"
)
// IPOption is an object for IP query options.
type IPOption struct {
IPv4Enable bool
IPv6Enable bool
}
type NameServerInterface interface {
type Client interface {
Name() string
QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error)
}

View File

@@ -20,7 +20,7 @@ import (
type Server struct {
sync.Mutex
hosts *StaticHosts
servers []NameServerInterface
clients []Client
clientIP net.IP
domainMatcher strmatcher.IndexMatcher
domainIndexMap map[uint32]uint32
@@ -29,7 +29,7 @@ type Server struct {
// New creates a new DNS server with given configuration.
func New(ctx context.Context, config *Config) (*Server, error) {
server := &Server{
servers: make([]NameServerInterface, 0, len(config.NameServers)+len(config.NameServer)),
clients: make([]Client, 0, len(config.NameServers)+len(config.NameServer)),
}
if len(config.ClientIp) > 0 {
if len(config.ClientIp) != 4 && len(config.ClientIp) != 16 {
@@ -47,22 +47,22 @@ func New(ctx context.Context, config *Config) (*Server, error) {
addNameServer := func(endpoint *net.Endpoint) int {
address := endpoint.Address.AsAddress()
if address.Family().IsDomain() && address.Domain() == "localhost" {
server.servers = append(server.servers, NewLocalNameServer())
server.clients = append(server.clients, NewLocalNameServer())
} else {
dest := endpoint.AsDestination()
if dest.Network == net.Network_Unknown {
dest.Network = net.Network_UDP
}
if dest.Network == net.Network_UDP {
idx := len(server.servers)
server.servers = append(server.servers, nil)
idx := len(server.clients)
server.clients = append(server.clients, nil)
common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
server.servers[idx] = NewClassicNameServer(dest, d, server.clientIP)
server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
}))
}
}
return len(server.servers) - 1
return len(server.clients) - 1
}
if len(config.NameServers) > 0 {
@@ -94,8 +94,8 @@ func New(ctx context.Context, config *Config) (*Server, error) {
server.domainIndexMap = domainIndexMap
}
if len(server.servers) == 0 {
server.servers = append(server.servers, NewLocalNameServer())
if len(server.clients) == 0 {
server.clients = append(server.clients, NewLocalNameServer())
}
return server, nil
@@ -116,9 +116,9 @@ func (s *Server) Close() error {
return nil
}
func (s *Server) queryIPTimeout(server NameServerInterface, domain string, option IPOption) ([]net.IP, error) {
func (s *Server) queryIPTimeout(client Client, domain string, option IPOption) ([]net.IP, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*4)
ips, err := server.QueryIP(ctx, domain, option)
ips, err := client.QueryIP(ctx, domain, option)
cancel()
return ips, err
}
@@ -156,7 +156,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
if s.domainMatcher != nil {
idx := s.domainMatcher.Match(domain)
if idx > 0 {
ns := s.servers[s.domainIndexMap[idx]]
ns := s.clients[s.domainIndexMap[idx]]
ips, err := s.queryIPTimeout(ns, domain, option)
if len(ips) > 0 {
return ips, nil
@@ -168,13 +168,13 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
}
}
for _, server := range s.servers {
ips, err := s.queryIPTimeout(server, domain, option)
for _, client := range s.clients {
ips, err := s.queryIPTimeout(client, domain, option)
if len(ips) > 0 {
return ips, nil
}
if err != nil {
newError("failed to lookup ip for domain ", domain, " at server ", server.Name()).Base(err).WriteToLog()
newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
lastErr = err
}
}

View File

@@ -8,10 +8,10 @@ import (
"time"
"golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol/dns"
"v2ray.com/core/common/session"
"v2ray.com/core/common/signal/pubsub"
"v2ray.com/core/common/task"
@@ -20,7 +20,7 @@ import (
)
type IPRecord struct {
IP net.IP
IP net.Address
Expire time.Time
}
@@ -149,7 +149,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf
break
}
ips = append(ips, IPRecord{
IP: net.IP(ans.A[:]),
IP: net.IPAddress(ans.A[:]),
Expire: now.Add(time.Duration(ttl) * time.Second),
})
case dnsmessage.TypeAAAA:
@@ -159,7 +159,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf
break
}
ips = append(ips, IPRecord{
IP: net.IP(ans.AAAA[:]),
IP: net.IPAddress(ans.AAAA[:]),
Expire: now.Add(time.Duration(ttl) * time.Second),
})
default:
@@ -293,25 +293,13 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
return msgs
}
func msgToBuffer2(msg *dnsmessage.Message) (*buf.Buffer, error) {
buffer := buf.New()
rawBytes := buffer.Extend(buf.Size)
packed, err := msg.AppendPack(rawBytes[:0])
if err != nil {
buffer.Release()
return nil, err
}
buffer.Resize(0, int32(len(packed)))
return buffer, nil
}
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
msgs := s.buildMsgs(domain, option)
for _, msg := range msgs {
b, err := msgToBuffer2(msg)
b, err := dns.PackMessage(msg)
common.Must(err)
s.udpServer.Dispatch(context.Background(), s.address, b)
}
@@ -323,7 +311,7 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []n
s.RUnlock()
if found && len(records) > 0 {
var ips []net.IP
var ips []net.Address
now := time.Now()
for _, rec := range records {
if rec.Expire.After(now) {

View File

@@ -31,32 +31,24 @@ func New(ctx context.Context, config *Config) (*Instance, error) {
}
func (g *Instance) initAccessLogger() error {
switch g.config.AccessLogType {
case LogType_File:
creator, err := log.CreateFileLogWriter(g.config.AccessLogPath)
if err != nil {
return err
}
g.accessLogger = log.NewLogger(creator)
case LogType_Console:
g.accessLogger = log.NewLogger(log.CreateStdoutLogWriter())
default:
handler, err := createHandler(g.config.AccessLogType, HandlerCreatorOptions{
Path: g.config.AccessLogPath,
})
if err != nil {
return err
}
g.accessLogger = handler
return nil
}
func (g *Instance) initErrorLogger() error {
switch g.config.ErrorLogType {
case LogType_File:
creator, err := log.CreateFileLogWriter(g.config.ErrorLogPath)
if err != nil {
return err
}
g.errorLogger = log.NewLogger(creator)
case LogType_Console:
g.errorLogger = log.NewLogger(log.CreateStdoutLogWriter())
default:
handler, err := createHandler(g.config.ErrorLogType, HandlerCreatorOptions{
Path: g.config.ErrorLogPath,
})
if err != nil {
return err
}
g.errorLogger = handler
return nil
}

51
app/log/log_creator.go Normal file
View File

@@ -0,0 +1,51 @@
package log
import (
"v2ray.com/core/common"
"v2ray.com/core/common/log"
)
type HandlerCreatorOptions struct {
Path string
}
type HandlerCreator func(LogType, HandlerCreatorOptions) (log.Handler, error)
var (
handlerCreatorMap = make(map[LogType]HandlerCreator)
)
func RegisterHandlerCreator(logType LogType, f HandlerCreator) error {
if f == nil {
return newError("nil HandlerCreator")
}
handlerCreatorMap[logType] = f
return nil
}
func createHandler(logType LogType, options HandlerCreatorOptions) (log.Handler, error) {
creator, found := handlerCreatorMap[logType]
if !found {
return nil, newError("unable to create log handler for ", logType)
}
return creator(logType, options)
}
func init() {
common.Must(RegisterHandlerCreator(LogType_Console, func(lt LogType, options HandlerCreatorOptions) (log.Handler, error) {
return log.NewLogger(log.CreateStdoutLogWriter()), nil
}))
common.Must(RegisterHandlerCreator(LogType_File, func(lt LogType, options HandlerCreatorOptions) (log.Handler, error) {
creator, err := log.CreateFileLogWriter(options.Path)
if err != nil {
return nil, err
}
return log.NewLogger(creator), nil
}))
common.Must(RegisterHandlerCreator(LogType_None, func(lt LogType, options HandlerCreatorOptions) (log.Handler, error) {
return nil, nil
}))
}

52
app/log/log_test.go Normal file
View File

@@ -0,0 +1,52 @@
package log_test
import (
"context"
"testing"
"github.com/golang/mock/gomock"
"v2ray.com/core/app/log"
"v2ray.com/core/common"
clog "v2ray.com/core/common/log"
"v2ray.com/core/testing/mocks"
)
func TestCustomLogHandler(t *testing.T) {
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
var loggedValue []string
mockHandler := mocks.NewLogHandler(mockCtl)
mockHandler.EXPECT().Handle(gomock.Any()).AnyTimes().DoAndReturn(func(msg clog.Message) {
loggedValue = append(loggedValue, msg.String())
})
log.RegisterHandlerCreator(log.LogType_Console, func(lt log.LogType, options log.HandlerCreatorOptions) (clog.Handler, error) {
return mockHandler, nil
})
logger, err := log.New(context.Background(), &log.Config{
ErrorLogLevel: clog.Severity_Debug,
ErrorLogType: log.LogType_Console,
AccessLogType: log.LogType_None,
})
common.Must(err)
common.Must(logger.Start())
clog.Record(&clog.GeneralMessage{
Severity: clog.Severity_Debug,
Content: "test",
})
if len(loggedValue) < 2 {
t.Fatal("expected 2 log messages, but actually ", loggedValue)
}
if loggedValue[1] != "[Debug] test" {
t.Fatal("expected '[Debug] test', but actually ", loggedValue[1])
}
common.Must(logger.Close())
}

View File

@@ -100,17 +100,17 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
if h.mux != nil {
if err := h.mux.Dispatch(ctx, link); err != nil {
newError("failed to process mux outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx))
pipe.CloseError(link.Writer)
common.Interrupt(link.Writer)
}
} else {
if err := h.proxy.Process(ctx, link, h); err != nil {
// Ensure outbound ray is properly closed.
newError("failed to process outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx))
pipe.CloseError(link.Writer)
common.Interrupt(link.Writer)
} else {
common.Must(common.Close(link.Writer))
}
pipe.CloseError(link.Reader)
common.Interrupt(link.Reader)
}
}

View File

@@ -97,7 +97,7 @@ func (o *Outbound) Tag() string {
func (o *Outbound) Dispatch(ctx context.Context, link *transport.Link) {
if err := o.portal.HandleConnection(ctx, link); err != nil {
newError("failed to process reverse connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
pipe.CloseError(link.Writer)
common.Interrupt(link.Writer)
}
}
@@ -244,7 +244,7 @@ func (w *PortalWorker) heartbeat() error {
defer func() {
common.Close(w.writer)
pipe.CloseError(w.reader)
common.Interrupt(w.reader)
w.writer = nil
}()
}

View File

@@ -111,9 +111,18 @@ func targetFromContent(ctx context.Context) net.Destination {
return outbound.Target
}
func resolvedIPFromContext(ctx context.Context) []net.IP {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
return nil
}
return outbound.ResolvedIPs
}
type MultiGeoIPMatcher struct {
matchers []*GeoIPMatcher
destFunc func(context.Context) net.Destination
matchers []*GeoIPMatcher
destFunc func(context.Context) net.Destination
resolvedIPFunc func(context.Context) []net.IP
}
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
@@ -126,17 +135,18 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
matchers = append(matchers, matcher)
}
var destFunc func(context.Context) net.Destination
if onSource {
destFunc = sourceFromContext
} else {
destFunc = targetFromContent
matcher := &MultiGeoIPMatcher{
matchers: matchers,
}
return &MultiGeoIPMatcher{
matchers: matchers,
destFunc: destFunc,
}, nil
if onSource {
matcher.destFunc = sourceFromContext
} else {
matcher.destFunc = targetFromContent
matcher.resolvedIPFunc = resolvedIPFromContext
}
return matcher, nil
}
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
@@ -146,10 +156,12 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
if dest.IsValid() && dest.Address.Family().IsIP() {
ips = append(ips, dest.Address.IP())
} else if resolver, ok := ResolvedIPsFromContext(ctx); ok {
resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
ips = append(ips, rip.IP())
}
if m.resolvedIPFunc != nil {
rips := m.resolvedIPFunc(ctx)
if len(rips) > 0 {
ips = append(ips, rips...)
}
}

View File

@@ -94,7 +94,8 @@ type Domain struct {
// Domain matching type.
Type Domain_Type `protobuf:"varint,1,opt,name=type,proto3,enum=v2ray.core.app.router.Domain_Type" json:"type,omitempty"`
// Domain value.
Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
// Attributes of this domain. May be used for filtering.
Attribute []*Domain_Attribute `protobuf:"bytes,3,rep,name=attribute,proto3" json:"attribute,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`

View File

@@ -7,32 +7,12 @@ import (
"v2ray.com/core"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/features/dns"
"v2ray.com/core/features/outbound"
"v2ray.com/core/features/routing"
)
type key uint32
const (
resolvedIPsKey key = iota
)
type IPResolver interface {
Resolve() []net.Address
}
func ContextWithResolveIPs(ctx context.Context, f IPResolver) context.Context {
return context.WithValue(ctx, resolvedIPsKey, f)
}
func ResolvedIPsFromContext(ctx context.Context) (IPResolver, bool) {
ips, ok := ctx.Value(resolvedIPsKey).(IPResolver)
return ips, ok
}
func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
r := new(Router)
@@ -91,34 +71,6 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error
return nil
}
type ipResolver struct {
dns dns.Client
ip []net.Address
domain string
resolved bool
}
func (r *ipResolver) Resolve() []net.Address {
if r.resolved {
return r.ip
}
newError("looking for IP for domain: ", r.domain).WriteToLog()
r.resolved = true
ips, err := r.dns.LookupIP(r.domain)
if err != nil {
newError("failed to get IP address").Base(err).WriteToLog()
}
if len(ips) == 0 {
return nil
}
r.ip = make([]net.Address, len(ips))
for i, ip := range ips {
r.ip[i] = net.IPAddress(ip)
}
return r.ip
}
func (r *Router) PickRoute(ctx context.Context) (string, error) {
rule, err := r.pickRouteInternal(ctx)
if err != nil {
@@ -127,17 +79,27 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) {
return rule.GetTag()
}
// PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
resolver := &ipResolver{
dns: r.dns,
func isDomainOutbound(outbound *session.Outbound) bool {
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
}
func (r *Router) resolveIP(outbound *session.Outbound) error {
domain := outbound.Target.Address.Domain()
ips, err := r.dns.LookupIP(domain)
if err != nil {
return err
}
outbound.ResolvedIPs = ips
return nil
}
// PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
outbound := session.OutboundFromContext(ctx)
if r.domainStrategy == Config_IpOnDemand {
if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() {
resolver.domain = outbound.Target.Address.Domain()
ctx = ContextWithResolveIPs(ctx, resolver)
if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
if err := r.resolveIP(outbound); err != nil {
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
@@ -147,21 +109,19 @@ func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
}
}
if outbound == nil || !outbound.Target.IsValid() {
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
return nil, common.ErrNoClue
}
dest := outbound.Target
if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
resolver.domain = dest.Address.Domain()
ips := resolver.Resolve()
if len(ips) > 0 {
ctx = ContextWithResolveIPs(ctx, resolver)
for _, rule := range r.rules {
if rule.Apply(ctx) {
return rule, nil
}
}
if err := r.resolveIP(outbound); err != nil {
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
return nil, common.ErrNoClue
}
// Try applying rules again if we have IPs.
for _, rule := range r.rules {
if rule.Apply(ctx) {
return rule, nil
}
}

View File

@@ -125,3 +125,72 @@ func TestIPOnDemand(t *testing.T) {
t.Error("expect tag 'test', bug actually ", tag)
}
}
func TestIPIfNonMatchDomain(t *testing.T) {
config := &Config{
DomainStrategy: Config_IpIfNonMatch,
Rule: []*RoutingRule{
{
TargetTag: &RoutingRule_Tag{
Tag: "test",
},
Cidr: []*CIDR{
{
Ip: []byte{192, 168, 0, 0},
Prefix: 16,
},
},
},
},
}
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl)
mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
}
}
func TestIPIfNonMatchIP(t *testing.T) {
config := &Config{
DomainStrategy: Config_IpIfNonMatch,
Rule: []*RoutingRule{
{
TargetTag: &RoutingRule_Tag{
Tag: "test",
},
Cidr: []*CIDR{
{
Ip: []byte{127, 0, 0, 0},
Prefix: 8,
},
},
},
},
}
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl)
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
}
}

View File

@@ -3,6 +3,7 @@ package buf
import (
"io"
"v2ray.com/core/common"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/serial"
)
@@ -121,6 +122,30 @@ func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int) {
return mb, totalBytes
}
// Compact returns another MultiBuffer by merging all content of the given one together.
func Compact(mb MultiBuffer) MultiBuffer {
if len(mb) == 0 {
return mb
}
mb2 := make(MultiBuffer, 0, len(mb))
last := mb[0]
for i := 1; i < len(mb); i++ {
curr := mb[i]
if last.Len()+curr.Len() > Size {
mb2 = append(mb2, last)
last = curr
} else {
common.Must2(last.ReadFrom(curr))
curr.Release()
}
}
mb2 = append(mb2, last)
return mb2
}
// SplitFirst splits the first Buffer from the beginning of the MultiBuffer.
func SplitFirst(mb MultiBuffer) (MultiBuffer, *Buffer) {
if len(mb) == 0 {
@@ -167,6 +192,25 @@ func SplitSize(mb MultiBuffer, size int32) (MultiBuffer, MultiBuffer) {
return mb, r
}
// WriteMultiBuffer writes all buffers from the MultiBuffer to the Writer one by one, and return error if any, with leftover MultiBuffer.
func WriteMultiBuffer(writer io.Writer, mb MultiBuffer) (MultiBuffer, error) {
for {
mb2, b := SplitFirst(mb)
mb = mb2
if b == nil {
break
}
_, err := writer.Write(b.Bytes())
b.Release()
if err != nil {
return mb, err
}
}
return nil, nil
}
// Len returns the total number of bytes in the MultiBuffer.
func (mb MultiBuffer) Len() int32 {
if mb == nil {

View File

@@ -148,12 +148,13 @@ func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
return nBytes, err
}
// Interrupt implements common.Interruptible.
func (r *BufferedReader) Interrupt() {
common.Interrupt(r.Reader)
}
// Close implements io.Closer.
func (r *BufferedReader) Close() error {
if !r.Buffer.IsEmpty() {
ReleaseMulti(r.Buffer)
r.Buffer = nil
}
return common.Close(r.Reader)
}

View File

@@ -7,6 +7,7 @@ import (
"net"
"testing"
"golang.org/x/sync/errgroup"
"v2ray.com/core/common"
. "v2ray.com/core/common/buf"
"v2ray.com/core/common/compare"
@@ -31,12 +32,17 @@ func TestReadvReader(t *testing.T) {
data := make([]byte, 8192)
common.Must2(rand.Read(data))
go func() {
var errg errgroup.Group
errg.Go(func() error {
writer := NewWriter(conn)
mb := MergeBytes(nil, data)
if err := writer.WriteMultiBuffer(mb); err != nil {
t.Fatal("failed to write data: ", err)
return writer.WriteMultiBuffer(mb)
})
defer func() {
if err := errg.Wait(); err != nil {
t.Error(err)
}
}()

View File

@@ -219,19 +219,9 @@ type SequentialWriter struct {
// WriteMultiBuffer implements Writer.
func (w *SequentialWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer ReleaseMulti(mb)
for _, b := range mb {
if b.IsEmpty() {
continue
}
if err := WriteAllBytes(w.Writer, b.Bytes()); err != nil {
return err
}
}
return nil
mb, err := WriteMultiBuffer(w.Writer, mb)
ReleaseMulti(mb)
return err
}
type noOpWriter byte

View File

@@ -1,12 +1,25 @@
package common
import "v2ray.com/core/common/errors"
// Closable is the interface for objects that can release its resources.
//
// v2ray:api:beta
type Closable interface {
// Close release all resources used by this object, including goroutines.
Close() error
}
// Interruptible is an interface for objects that can be stopped before its completion.
//
// v2ray:api:beta
type Interruptible interface {
Interrupt()
}
// Close closes the obj if it is a Closable.
//
// v2ray:api:beta
func Close(obj interface{}) error {
if c, ok := obj.(Closable); ok {
return c.Close()
@@ -14,6 +27,17 @@ func Close(obj interface{}) error {
return nil
}
// Interrupt calls Interrupt() if object implements Interruptible interface, or Close() if the object implements Closable interface.
//
// v2ray:api:beta
func Interrupt(obj interface{}) error {
if c, ok := obj.(Interruptible); ok {
c.Interrupt()
return nil
}
return Close(obj)
}
// Runnable is the interface for objects that can start to work and stop on demand.
type Runnable interface {
// Start starts the runnable object. Upon the method returning nil, the object begins to function properly.
@@ -29,15 +53,16 @@ type HasType interface {
Type() interface{}
}
// ChainedClosable is a Closable that consists of multiple Closable objects.
type ChainedClosable []Closable
func NewChainedClosable(c ...Closable) ChainedClosable {
return ChainedClosable(c)
}
// Close implements Closable.
func (cc ChainedClosable) Close() error {
var errs []error
for _, c := range cc {
c.Close()
if err := c.Close(); err != nil {
errs = append(errs, err)
}
}
return nil
return errors.Combine(errs...)
}

View File

@@ -213,8 +213,8 @@ func (m *ClientWorker) monitor() {
select {
case <-m.done.Wait():
m.sessionManager.Close()
common.Close(m.link.Writer) // nolint: errcheck
pipe.CloseError(m.link.Reader) // nolint: errcheck
common.Close(m.link.Writer) // nolint: errcheck
common.Interrupt(m.link.Reader) // nolint: errcheck
return
case <-timer.C:
size := m.sessionManager.Size()
@@ -253,14 +253,14 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
if err := writeFirstPayload(s.input, writer); err != nil {
newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
pipe.CloseError(s.input)
common.Interrupt(s.input)
return
}
if err := buf.Copy(s.input, writer); err != nil {
newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
pipe.CloseError(s.input)
common.Interrupt(s.input)
return
}
}
@@ -339,7 +339,7 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
closingWriter.Close()
drainErr := buf.Copy(rr, buf.Discard)
pipe.CloseError(s.input)
common.Interrupt(s.input)
s.Close()
return drainErr
}
@@ -350,8 +350,8 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := m.sessionManager.Get(meta.SessionID); found {
if meta.Option.Has(OptionError) {
pipe.CloseError(s.input)
pipe.CloseError(s.output)
common.Interrupt(s.input)
common.Interrupt(s.output)
}
s.Close()
}

View File

@@ -5,6 +5,7 @@ import (
"io"
"v2ray.com/core"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/log"
@@ -146,7 +147,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
rr := s.NewReader(reader)
if err := buf.Copy(rr, s.output); err != nil {
buf.Copy(rr, buf.Discard)
pipe.CloseError(s.input)
common.Interrupt(s.input)
return s.Close()
}
return nil
@@ -177,7 +178,7 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
closingWriter.Close()
drainErr := buf.Copy(rr, buf.Discard)
pipe.CloseError(s.input)
common.Interrupt(s.input)
s.Close()
return drainErr
}
@@ -188,8 +189,8 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere
func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := w.sessionManager.Get(meta.SessionID); found {
if meta.Option.Has(OptionError) {
pipe.CloseError(s.input)
pipe.CloseError(s.output)
common.Interrupt(s.input)
common.Interrupt(s.output)
}
s.Close()
}
@@ -241,7 +242,7 @@ func (w *ServerWorker) run(ctx context.Context) {
if err != nil {
if errors.Cause(err) != io.EOF {
newError("unexpected EOF").Base(err).WriteToLog(session.ExportIDToError(ctx))
pipe.CloseError(input)
common.Interrupt(input)
}
return
}

View File

@@ -115,7 +115,7 @@ func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
// Close implements net.Conn.Close().
func (c *connection) Close() error {
common.Must(c.done.Close())
common.Close(c.reader)
common.Interrupt(c.reader)
common.Close(c.writer)
if c.onClose != nil {
return c.onClose.Close()

77
common/protocol/dns/io.go Normal file
View File

@@ -0,0 +1,77 @@
package dns
import (
"sync"
"golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
)
func PackMessage(msg *dnsmessage.Message) (*buf.Buffer, error) {
buffer := buf.New()
rawBytes := buffer.Extend(buf.Size)
packed, err := msg.AppendPack(rawBytes[:0])
if err != nil {
buffer.Release()
return nil, err
}
buffer.Resize(0, int32(len(packed)))
return buffer, nil
}
type MessageReader interface {
ReadMessage() (*buf.Buffer, error)
}
type UDPReader struct {
buf.Reader
access sync.Mutex
cache buf.MultiBuffer
}
func (r *UDPReader) readCache() *buf.Buffer {
r.access.Lock()
defer r.access.Unlock()
mb, b := buf.SplitFirst(r.cache)
r.cache = mb
return b
}
func (r *UDPReader) refill() error {
mb, err := r.Reader.ReadMultiBuffer()
if err != nil {
return err
}
r.access.Lock()
r.cache = mb
r.access.Unlock()
return nil
}
// ReadMessage implements MessageReader.
func (r *UDPReader) ReadMessage() (*buf.Buffer, error) {
for {
b := r.readCache()
if b != nil {
return b, nil
}
if err := r.refill(); err != nil {
return nil, err
}
}
}
// Close implements common.Closable.
func (r *UDPReader) Close() error {
defer func() {
r.access.Lock()
buf.ReleaseMulti(r.cache)
r.cache = nil
r.access.Unlock()
}()
return common.Close(r.Reader)
}

View File

@@ -2,6 +2,7 @@ package http
import (
"net/http"
"strconv"
"strings"
"v2ray.com/core/common/net"
@@ -42,3 +43,24 @@ func RemoveHopByHopHeaders(header http.Header) {
header.Del(strings.TrimSpace(h))
}
}
// ParseHost splits host and port from a raw string. Default port is used when raw string doesn't contain port.
func ParseHost(rawHost string, defaultPort net.Port) (net.Destination, error) {
port := defaultPort
host, rawPort, err := net.SplitHostPort(rawHost)
if err != nil {
if addrError, ok := err.(*net.AddrError); ok && strings.Contains(addrError.Err, "missing port") {
host = rawHost
} else {
return net.Destination{}, err
}
} else if len(rawPort) > 0 {
intPort, err := strconv.Atoi(rawPort)
if err != nil {
return net.Destination{}, err
}
port = net.Port(intPort)
}
return net.TCPDestination(net.ParseAddress(host), port), nil
}

View File

@@ -6,6 +6,8 @@ import (
"strings"
"testing"
"v2ray.com/core/common/net"
. "v2ray.com/core/common/protocol/http"
. "v2ray.com/ext/assert"
)
@@ -53,3 +55,41 @@ Accept-Language: de,en;q=0.7,en-us;q=0.3
assert(req.Header.Get("Proxy-Connection"), IsEmpty)
assert(req.Header.Get("Proxy-Authenticate"), IsEmpty)
}
func TestParseHost(t *testing.T) {
testCases := []struct {
RawHost string
DefaultPort net.Port
Destination net.Destination
Error bool
}{
{
RawHost: "v2ray.com:80",
DefaultPort: 443,
Destination: net.TCPDestination(net.DomainAddress("v2ray.com"), 80),
},
{
RawHost: "tls.v2ray.com",
DefaultPort: 443,
Destination: net.TCPDestination(net.DomainAddress("tls.v2ray.com"), 443),
},
{
RawHost: "[2401:1bc0:51f0:ec08::1]:80",
DefaultPort: 443,
Destination: net.TCPDestination(net.ParseAddress("[2401:1bc0:51f0:ec08::1]"), 80),
},
}
for _, testCase := range testCases {
dest, err := ParseHost(testCase.RawHost, testCase.DefaultPort)
if testCase.Error {
if err == nil {
t.Error("for test case: ", testCase.RawHost, " expected error, but actually nil")
}
} else {
if dest != testCase.Destination {
t.Error("for test case: ", testCase.RawHost, " expected host: ", testCase.Destination.String(), " but got ", dest.String())
}
}
}
}

View File

@@ -6,6 +6,7 @@ import (
"strings"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
)
type version byte
@@ -75,10 +76,13 @@ func SniffHTTP(b []byte) (*SniffHeader, error) {
continue
}
key := strings.ToLower(string(parts[0]))
value := strings.ToLower(string(bytes.Trim(parts[1], " ")))
if key == "host" {
domain := strings.Split(value, ":")
sh.host = strings.TrimSpace(domain[0])
rawHost := strings.ToLower(string(bytes.TrimSpace(parts[1])))
dest, err := ParseHost(rawHost, net.Port(80))
if err != nil {
return nil, err
}
sh.host = dest.Address.String()
}
}

View File

@@ -81,6 +81,66 @@ func TestTLSHeaders(t *testing.T) {
domain: "www07.clicktale.net",
err: false,
},
{
input: []byte{
0x16, 0x03, 0x01, 0x00, 0xe6, 0x01, 0x00, 0x00, 0xe2, 0x03, 0x03, 0x81, 0x47, 0xc1,
0x66, 0xd5, 0x1b, 0xfa, 0x4b, 0xb5, 0xe0, 0x2a, 0xe1, 0xa7, 0x87, 0x13, 0x1d, 0x11, 0xaa, 0xc6,
0xce, 0xfc, 0x7f, 0xab, 0x94, 0xc8, 0x62, 0xad, 0xc8, 0xab, 0x0c, 0xdd, 0xcb, 0x20, 0x6f, 0x9d,
0x07, 0xf1, 0x95, 0x3e, 0x99, 0xd8, 0xf3, 0x6d, 0x97, 0xee, 0x19, 0x0b, 0x06, 0x1b, 0xf4, 0x84,
0x0b, 0xb6, 0x8f, 0xcc, 0xde, 0xe2, 0xd0, 0x2d, 0x6b, 0x0c, 0x1f, 0x52, 0x53, 0x13, 0x00, 0x08,
0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0x00, 0xff, 0x01, 0x00, 0x00, 0x91, 0x00, 0x00, 0x00, 0x0c,
0x00, 0x0a, 0x00, 0x00, 0x07, 0x64, 0x6f, 0x67, 0x66, 0x69, 0x73, 0x68, 0x00, 0x0b, 0x00, 0x04,
0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x0a, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e,
0x00, 0x19, 0x00, 0x18, 0x00, 0x23, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00,
0x00, 0x0d, 0x00, 0x1e, 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08,
0x08, 0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
0x06, 0x01, 0x00, 0x2b, 0x00, 0x07, 0x06, 0x7f, 0x1c, 0x7f, 0x1b, 0x7f, 0x1a, 0x00, 0x2d, 0x00,
0x02, 0x01, 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x2f, 0x35, 0x0c,
0xb6, 0x90, 0x0a, 0xb7, 0xd5, 0xc4, 0x1b, 0x2f, 0x60, 0xaa, 0x56, 0x7b, 0x3f, 0x71, 0xc8, 0x01,
0x7e, 0x86, 0xd3, 0xb7, 0x0c, 0x29, 0x1a, 0x9e, 0x5b, 0x38, 0x3f, 0x01, 0x72,
},
domain: "dogfish",
err: false,
},
{
input: []byte{
0x16, 0x03, 0x01, 0x01, 0x03, 0x01, 0x00, 0x00,
0xff, 0x03, 0x03, 0x3d, 0x89, 0x52, 0x9e, 0xee,
0xbe, 0x17, 0x63, 0x75, 0xef, 0x29, 0xbd, 0x14,
0x6a, 0x49, 0xe0, 0x2c, 0x37, 0x57, 0x71, 0x62,
0x82, 0x44, 0x94, 0x8f, 0x6e, 0x94, 0x08, 0x45,
0x7f, 0xdb, 0xc1, 0x00, 0x00, 0x3e, 0xc0, 0x2c,
0xc0, 0x30, 0x00, 0x9f, 0xcc, 0xa9, 0xcc, 0xa8,
0xcc, 0xaa, 0xc0, 0x2b, 0xc0, 0x2f, 0x00, 0x9e,
0xc0, 0x24, 0xc0, 0x28, 0x00, 0x6b, 0xc0, 0x23,
0xc0, 0x27, 0x00, 0x67, 0xc0, 0x0a, 0xc0, 0x14,
0x00, 0x39, 0xc0, 0x09, 0xc0, 0x13, 0x00, 0x33,
0x00, 0x9d, 0x00, 0x9c, 0x13, 0x02, 0x13, 0x03,
0x13, 0x01, 0x00, 0x3d, 0x00, 0x3c, 0x00, 0x35,
0x00, 0x2f, 0x00, 0xff, 0x01, 0x00, 0x00, 0x98,
0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00,
0x0b, 0x31, 0x30, 0x2e, 0x34, 0x32, 0x2e, 0x30,
0x2e, 0x32, 0x34, 0x33, 0x00, 0x0b, 0x00, 0x04,
0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x0a,
0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x19,
0x00, 0x18, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0d,
0x00, 0x20, 0x00, 0x1e, 0x04, 0x03, 0x05, 0x03,
0x06, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06,
0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x03,
0x02, 0x01, 0x02, 0x02, 0x04, 0x02, 0x05, 0x02,
0x06, 0x02, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17,
0x00, 0x00, 0x00, 0x2b, 0x00, 0x09, 0x08, 0x7f,
0x14, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01, 0x00,
0x2d, 0x00, 0x03, 0x02, 0x01, 0x00, 0x00, 0x28,
0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20,
0x13, 0x7c, 0x6e, 0x97, 0xc4, 0xfd, 0x09, 0x2e,
0x70, 0x2f, 0x73, 0x5a, 0x9b, 0x57, 0x4d, 0x5f,
0x2b, 0x73, 0x2c, 0xa5, 0x4a, 0x98, 0x40, 0x3d,
0x75, 0x6e, 0xb4, 0x76, 0xf9, 0x48, 0x8f, 0x36,
},
domain: "10.42.0.243",
err: false,
},
}
for _, test := range cases {

View File

@@ -2,7 +2,8 @@ package task
import "v2ray.com/core/common"
func Close(v interface{}) Task {
// Close returns a func() that closes v.
func Close(v interface{}) func() error {
return func() error {
return common.Close(v)
}

View File

@@ -6,121 +6,25 @@ import (
"v2ray.com/core/common/signal/semaphore"
)
type Task func() error
type executionContext struct {
ctx context.Context
tasks []Task
onSuccess Task
onFailure Task
}
func (c *executionContext) executeTask() error {
if len(c.tasks) == 0 {
return nil
}
// Reuse current goroutine if we only have one task to run.
if len(c.tasks) == 1 && c.ctx == nil {
return c.tasks[0]()
}
ctx := context.Background()
if c.ctx != nil {
ctx = c.ctx
}
return executeParallel(ctx, c.tasks)
}
func (c *executionContext) run() error {
err := c.executeTask()
if err == nil && c.onSuccess != nil {
return c.onSuccess()
}
if err != nil && c.onFailure != nil {
return c.onFailure()
}
return err
}
type ExecutionOption func(*executionContext)
func WithContext(ctx context.Context) ExecutionOption {
return func(c *executionContext) {
c.ctx = ctx
}
}
func Parallel(tasks ...Task) ExecutionOption {
return func(c *executionContext) {
c.tasks = append(c.tasks, tasks...)
}
}
// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
// Once a task returns an error, the following tasks will not run.
func Sequential(tasks ...Task) ExecutionOption {
return func(c *executionContext) {
switch len(tasks) {
case 0:
return
case 1:
c.tasks = append(c.tasks, tasks[0])
default:
c.tasks = append(c.tasks, func() error {
return execute(tasks...)
})
}
}
}
func OnSuccess(task Task) ExecutionOption {
return func(c *executionContext) {
c.onSuccess = task
}
}
func OnFailure(task Task) ExecutionOption {
return func(c *executionContext) {
c.onFailure = task
}
}
func Single(task Task, opts ...ExecutionOption) Task {
return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
}
func Run(opts ...ExecutionOption) Task {
var c executionContext
for _, opt := range opts {
opt(&c)
}
// OnSuccess executes g() after f() returns nil.
func OnSuccess(f func() error, g func() error) func() error {
return func() error {
return c.run()
}
}
// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
func execute(tasks ...Task) error {
for _, task := range tasks {
if err := task(); err != nil {
if err := f(); err != nil {
return err
}
return g()
}
return nil
}
// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
func executeParallel(ctx context.Context, tasks []Task) error {
// Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass.
func Run(ctx context.Context, tasks ...func() error) error {
n := len(tasks)
s := semaphore.New(n)
done := make(chan error, 1)
for _, task := range tasks {
<-s.Wait()
go func(f Task) {
go func(f func() error) {
err := f()
if err == nil {
s.Signal()

View File

@@ -14,13 +14,14 @@ import (
func TestExecuteParallel(t *testing.T) {
assert := With(t)
err := Run(Parallel(func() error {
time.Sleep(time.Millisecond * 200)
return errors.New("test")
}, func() error {
time.Sleep(time.Millisecond * 500)
return errors.New("test2")
}))()
err := Run(context.Background(),
func() error {
time.Sleep(time.Millisecond * 200)
return errors.New("test")
}, func() error {
time.Sleep(time.Millisecond * 500)
return errors.New("test2")
})
assert(err.Error(), Equals, "test")
}
@@ -29,7 +30,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
assert := With(t)
ctx, cancel := context.WithCancel(context.Background())
err := Run(WithContext(ctx), Parallel(func() error {
err := Run(ctx, func() error {
time.Sleep(time.Millisecond * 2000)
return errors.New("test")
}, func() error {
@@ -38,7 +39,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
}, func() error {
cancel()
return nil
}))()
})
assert(err.Error(), HasSubstring, "canceled")
}
@@ -48,7 +49,7 @@ func BenchmarkExecuteOne(b *testing.B) {
return nil
}
for i := 0; i < b.N; i++ {
common.Must(Run(Parallel(noop))())
common.Must(Run(context.Background(), noop))
}
}
@@ -57,17 +58,6 @@ func BenchmarkExecuteTwo(b *testing.B) {
return nil
}
for i := 0; i < b.N; i++ {
common.Must(Run(Parallel(noop, noop))())
}
}
func BenchmarkExecuteContext(b *testing.B) {
noop := func() error {
return nil
}
background := context.Background()
for i := 0; i < b.N; i++ {
common.Must(Run(WithContext(background), Parallel(noop, noop))())
common.Must(Run(context.Background(), noop, noop))
}
}

View File

@@ -17,7 +17,7 @@ import (
)
const (
version = "4.8"
version = "4.10"
build = "Custom"
codename = "Po"
intro = "A unified platform for anti-censorship."

View File

@@ -1,16 +1,12 @@
package localdns
import (
"context"
"v2ray.com/core/common/net"
"v2ray.com/core/features/dns"
)
// Client is an implementation of dns.Client, which queries localhost for DNS.
type Client struct {
resolver net.Resolver
}
type Client struct{}
// Type implements common.HasType.
func (*Client) Type() interface{} {
@@ -24,16 +20,19 @@ func (*Client) Start() error { return nil }
func (*Client) Close() error { return nil }
// LookupIP implements Client.
func (c *Client) LookupIP(host string) ([]net.IP, error) {
ipAddr, err := c.resolver.LookupIPAddr(context.Background(), host)
func (*Client) LookupIP(host string) ([]net.IP, error) {
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
ips := make([]net.IP, 0, len(ipAddr))
for _, addr := range ipAddr {
ips = append(ips, addr.IP)
parsedIPs := make([]net.IP, 0, len(ips))
for _, ip := range ips {
parsed := net.IPAddress(ip)
if parsed != nil {
parsedIPs = append(parsedIPs, parsed.IP())
}
}
return ips, nil
return parsedIPs, nil
}
// LookupIPv4 implements IPv4Lookup.
@@ -42,11 +41,10 @@ func (c *Client) LookupIPv4(host string) ([]net.IP, error) {
if err != nil {
return nil, err
}
var ipv4 []net.IP
ipv4 := make([]net.IP, 0, len(ips))
for _, ip := range ips {
parsed := net.IPAddress(ip)
if parsed.Family().IsIPv4() {
ipv4 = append(ipv4, parsed.IP())
if len(ip) == net.IPv4len {
ipv4 = append(ipv4, ip)
}
}
return ipv4, nil
@@ -58,11 +56,10 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) {
if err != nil {
return nil, err
}
var ipv6 []net.IP
ipv6 := make([]net.IP, 0, len(ips))
for _, ip := range ips {
parsed := net.IPAddress(ip)
if parsed.Family().IsIPv6() {
ipv6 = append(ipv6, parsed.IP())
if len(ip) == net.IPv6len {
ipv6 = append(ipv6, ip)
}
}
return ipv6, nil
@@ -70,9 +67,5 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) {
// New create a new dns.Client that queries localhost for DNS.
func New() *Client {
return &Client{
resolver: net.Resolver{
PreferGo: true,
},
}
return &Client{}
}

View File

@@ -23,7 +23,6 @@ var (
version = flag.Bool("version", false, "Show current version of V2Ray.")
test = flag.Bool("test", false, "Test config file only, without launching V2Ray server.")
format = flag.String("format", "json", "Format of input file.")
plugin = flag.Bool("plugin", false, "True to load plugins.")
)
func fileExists(file string) bool {
@@ -96,13 +95,6 @@ func main() {
return
}
if *plugin {
if err := core.LoadPlugins(); err != nil {
fmt.Println("Failed to load plugins:", err.Error())
os.Exit(-1)
}
}
server, err := startV2Ray()
if err != nil {
fmt.Println(err.Error())

View File

@@ -4,6 +4,7 @@ package core
//go:generate go install github.com/golang/mock/mockgen
//go:generate mockgen -package mocks -destination testing/mocks/io.go -mock_names Reader=Reader,Writer=Writer io Reader,Writer
//go:generate mockgen -package mocks -destination testing/mocks/log.go -mock_names Handler=LogHandler v2ray.com/core/common/log Handler
//go:generate mockgen -package mocks -destination testing/mocks/mux.go -mock_names ClientWorkerFactory=MuxClientWorkerFactory v2ray.com/core/common/mux ClientWorkerFactory
//go:generate mockgen -package mocks -destination testing/mocks/dns.go -mock_names Client=DNSClient v2ray.com/core/features/dns Client
//go:generate mockgen -package mocks -destination testing/mocks/outbound.go -mock_names Manager=OutboundManager,HandlerSelector=OutboundHandlerSelector v2ray.com/core/features/outbound Manager,HandlerSelector

View File

@@ -1,18 +0,0 @@
package core
// PluginMetadata contains some brief information regarding a plugin.
type PluginMetadata struct {
// Name of the plugin
Name string
}
// GetMetadataFuncName is the name of the function in the plugin to return PluginMetadata.
const GetMetadataFuncName = "GetPluginMetadata"
// GetMetadataFunc is the type of the function in the plugin to return PluginMetadata.
type GetMetadataFunc func() PluginMetadata
// LoadPlugins loads all possible plugins in the 'plugin' directory.
func LoadPlugins() error {
return loadPluginsInternal()
}

View File

@@ -1,46 +0,0 @@
// +build linux
package core
import (
"os"
"path/filepath"
"plugin"
"strings"
"v2ray.com/core/common/platform"
)
func loadPluginsInternal() error {
pluginPath := platform.GetPluginDirectory()
dir, err := os.Open(pluginPath)
if err != nil {
return err
}
defer dir.Close()
files, err := dir.Readdir(-1)
if err != nil {
return err
}
for _, file := range files {
if !file.IsDir() && strings.HasSuffix(file.Name(), ".so") {
p, err := plugin.Open(filepath.Join(pluginPath, file.Name()))
if err != nil {
return err
}
f, err := p.Lookup(GetMetadataFuncName)
if err != nil {
return err
}
if gmf, ok := f.(GetMetadataFunc); ok {
metadata := gmf()
newError("plugin (", metadata.Name, ") loaded.").WriteToLog()
}
}
}
return nil
}

View File

@@ -1,7 +0,0 @@
// +build !linux
package core
func loadPluginsInternal() error {
return nil
}

View File

@@ -10,7 +10,6 @@ import (
"v2ray.com/core/common"
"v2ray.com/core/transport"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/pipe"
)
// Handler is an outbound connection that silently swallow the entire payload.
@@ -36,7 +35,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
// Sleep a little here to make sure the response is sent to client.
time.Sleep(time.Second)
}
pipe.CloseError(link.Writer)
common.Interrupt(link.Writer)
return nil
}

View File

@@ -16,7 +16,6 @@ import (
"v2ray.com/core/features/policy"
"v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/pipe"
)
func init() {
@@ -147,12 +146,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
return nil
}
if err := task.Run(task.WithContext(ctx),
task.Parallel(
task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))),
responseDone))(); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}

View File

@@ -160,14 +160,20 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
responseDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
if err := buf.Copy(buf.NewReader(conn), output, buf.UpdateActivity(timer)); err != nil {
var reader buf.Reader
if destination.Network == net.Network_TCP {
reader = buf.NewReader(conn)
} else {
reader = &buf.PacketReader{Reader: conn}
}
if err := buf.Copy(reader, output, buf.UpdateActivity(timer)); err != nil {
return newError("failed to process response").Base(err)
}
return nil
}
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil {
if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -6,7 +6,6 @@ import (
"encoding/base64"
"io"
"net/http"
"strconv"
"strings"
"time"
@@ -23,7 +22,6 @@ import (
"v2ray.com/core/features/policy"
"v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/pipe"
)
// Server is an HTTP proxy server.
@@ -52,30 +50,11 @@ func (s *Server) policy() policy.Session {
return p
}
// Network implements proxy.Inbound.
func (*Server) Network() []net.Network {
return []net.Network{net.Network_TCP}
}
func parseHost(rawHost string, defaultPort net.Port) (net.Destination, error) {
port := defaultPort
host, rawPort, err := net.SplitHostPort(rawHost)
if err != nil {
if addrError, ok := err.(*net.AddrError); ok && strings.Contains(addrError.Err, "missing port") {
host = rawHost
} else {
return net.Destination{}, err
}
} else if len(rawPort) > 0 {
intPort, err := strconv.Atoi(rawPort)
if err != nil {
return net.Destination{}, err
}
port = net.Port(intPort)
}
return net.TCPDestination(net.ParseAddress(host), port), nil
}
func isTimeout(err error) bool {
nerr, ok := errors.Cause(err).(net.Error)
return ok && nerr.Timeout()
@@ -139,7 +118,7 @@ Start:
if len(host) == 0 {
host = request.URL.Host
}
dest, err := parseHost(host, defaultPort)
dest, err := http_proto.ParseHost(host, defaultPort)
if err != nil {
return newError("malformed proxy host: ", host).AtWarning().Base(err)
}
@@ -210,10 +189,10 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
return nil
}
var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
var closeWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, closeWriter, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}
@@ -307,9 +286,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
return nil
}
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDone))(); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
if err := task.Run(ctx, requestDone, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}

View File

@@ -62,8 +62,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(connReader, link.Writer)
}
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -17,7 +17,6 @@ import (
"v2ray.com/core/features/policy"
"v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/pipe"
)
var (
@@ -141,10 +140,10 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
}
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}

View File

@@ -129,8 +129,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer))
}
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}
@@ -167,8 +167,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return nil
}
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -17,7 +17,6 @@ import (
"v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/udp"
"v2ray.com/core/transport/pipe"
)
type Server struct {
@@ -229,10 +228,10 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
return nil
}
var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}

View File

@@ -137,8 +137,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
}
}
var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil {
var responseDonePost = task.OnSuccess(responseFunc, task.Close(link.Writer))
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -19,7 +19,6 @@ import (
"v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/udp"
"v2ray.com/core/transport/pipe"
)
// Server is a SOCKS 5 proxy server
@@ -164,10 +163,10 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
return nil
}
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}

View File

@@ -26,7 +26,6 @@ import (
"v2ray.com/core/proxy/vmess"
"v2ray.com/core/proxy/vmess/encoding"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/pipe"
)
type userByEmail struct {
@@ -302,10 +301,10 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
return transferResponse(timer, svrSession, request, response, link.Reader, writer)
}
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return newError("connection ends").Base(err)
}

View File

@@ -161,8 +161,8 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
}
var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil {
var responseDonePost = task.OnSuccess(responseDone, task.Close(output))
if err := task.Run(ctx, requestDone, responseDonePost); err != nil {
return newError("connection ends").Base(err)
}

1696
release/config/geosite.dat Executable file → Normal file

File diff suppressed because one or more lines are too long

View File

@@ -197,7 +197,7 @@ getVersion(){
else
VER=`/usr/bin/v2ray/v2ray -version 2>/dev/null`
RETVAL="$?"
CUR_VER=`echo $VER | head -n 1 | cut -d " " -f2 | cut -d. -f-2`
CUR_VER=`echo $VER | head -n 1 | cut -d " " -f2`
if [[ ${CUR_VER} != v* ]]; then
CUR_VER=v${CUR_VER}
fi
@@ -211,7 +211,7 @@ getVersion(){
return 3
elif [[ $RETVAL -ne 0 ]];then
return 2
elif [[ "$NEW_VER" != "$CUR_VER" ]];then
elif [[ `echo $NEW_VER | cut -d. -f-2` != `echo $CUR_VER | cut -d. -f-2` ]];then
return 1
fi
return 0

View File

@@ -62,13 +62,17 @@ popd
GEOIP_TAG=$(curl --silent "https://api.github.com/repos/v2ray/geoip/releases/latest" | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/')
curl -L -o release/config/geoip.dat "https://github.com/v2ray/geoip/releases/download/${GEOIP_TAG}/geoip.dat"
# Update geosite.dat
GEOIP_TAG=$(curl --silent "https://api.github.com/repos/v2ray/domain-list-community/releases/latest" | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/')
curl -L -o release/config/geosite.dat "https://github.com/v2ray/domain-list-community/releases/download/${GEOIP_TAG}/dlc.dat"
# Take a snapshot of all required source code
pushd $GOPATH/src
# Flatten vendor directories
cp -r v2ray.com/core/vendor/github.com/ ./github.com/
rm -rf v2ray.com/core/vendor/github.com/
cp -r github.com/lucas-clemente/quic-go/vendor/github.com/
cp -r v2ray.com/core/vendor/github.com/ .
rm -rf v2ray.com/core/vendor/
cp -r github.com/lucas-clemente/quic-go/vendor/github.com/ .
rm -rf github.com/lucas-clemente/quic-go/vendor/
# Create zip file for all sources
@@ -89,6 +93,8 @@ function uploadfile() {
FILE=$1
CTYPE=$(file -b --mime-type $FILE)
curl -H "Authorization: token ${GITHUB_TOKEN}" -H "Content-Type: ${CTYPE}" --data-binary @$FILE "https://uploads.github.com/repos/v2ray/v2ray-core/releases/${RELEASE_ID}/assets?name=$(basename $FILE)"
sleep 1
}
function upload() {

View File

@@ -35,6 +35,7 @@ func (m *DNSClient) EXPECT() *DNSClientMockRecorder {
// Close mocks base method
func (m *DNSClient) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
@@ -42,11 +43,13 @@ func (m *DNSClient) Close() error {
// Close indicates an expected call of Close
func (mr *DNSClientMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*DNSClient)(nil).Close))
}
// LookupIP mocks base method
func (m *DNSClient) LookupIP(arg0 string) ([]net.IP, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LookupIP", arg0)
ret0, _ := ret[0].([]net.IP)
ret1, _ := ret[1].(error)
@@ -55,11 +58,13 @@ func (m *DNSClient) LookupIP(arg0 string) ([]net.IP, error) {
// LookupIP indicates an expected call of LookupIP
func (mr *DNSClientMockRecorder) LookupIP(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupIP", reflect.TypeOf((*DNSClient)(nil).LookupIP), arg0)
}
// Start mocks base method
func (m *DNSClient) Start() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Start")
ret0, _ := ret[0].(error)
return ret0
@@ -67,11 +72,13 @@ func (m *DNSClient) Start() error {
// Start indicates an expected call of Start
func (mr *DNSClientMockRecorder) Start() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*DNSClient)(nil).Start))
}
// Type mocks base method
func (m *DNSClient) Type() interface{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Type")
ret0, _ := ret[0].(interface{})
return ret0
@@ -79,5 +86,6 @@ func (m *DNSClient) Type() interface{} {
// Type indicates an expected call of Type
func (mr *DNSClientMockRecorder) Type() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", reflect.TypeOf((*DNSClient)(nil).Type))
}

View File

@@ -34,6 +34,7 @@ func (m *Reader) EXPECT() *ReaderMockRecorder {
// Read mocks base method
func (m *Reader) Read(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
@@ -42,6 +43,7 @@ func (m *Reader) Read(arg0 []byte) (int, error) {
// Read indicates an expected call of Read
func (mr *ReaderMockRecorder) Read(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*Reader)(nil).Read), arg0)
}
@@ -70,6 +72,7 @@ func (m *Writer) EXPECT() *WriterMockRecorder {
// Write mocks base method
func (m *Writer) Write(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
@@ -78,5 +81,6 @@ func (m *Writer) Write(arg0 []byte) (int, error) {
// Write indicates an expected call of Write
func (mr *WriterMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*Writer)(nil).Write), arg0)
}

46
testing/mocks/log.go Normal file
View File

@@ -0,0 +1,46 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: v2ray.com/core/common/log (interfaces: Handler)
// Package mocks is a generated GoMock package.
package mocks
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
log "v2ray.com/core/common/log"
)
// LogHandler is a mock of Handler interface
type LogHandler struct {
ctrl *gomock.Controller
recorder *LogHandlerMockRecorder
}
// LogHandlerMockRecorder is the mock recorder for LogHandler
type LogHandlerMockRecorder struct {
mock *LogHandler
}
// NewLogHandler creates a new mock instance
func NewLogHandler(ctrl *gomock.Controller) *LogHandler {
mock := &LogHandler{ctrl: ctrl}
mock.recorder = &LogHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *LogHandler) EXPECT() *LogHandlerMockRecorder {
return m.recorder
}
// Handle mocks base method
func (m *LogHandler) Handle(arg0 log.Message) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Handle", arg0)
}
// Handle indicates an expected call of Handle
func (mr *LogHandlerMockRecorder) Handle(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handle", reflect.TypeOf((*LogHandler)(nil).Handle), arg0)
}

View File

@@ -35,6 +35,7 @@ func (m *MuxClientWorkerFactory) EXPECT() *MuxClientWorkerFactoryMockRecorder {
// Create mocks base method
func (m *MuxClientWorkerFactory) Create() (*mux.ClientWorker, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create")
ret0, _ := ret[0].(*mux.ClientWorker)
ret1, _ := ret[1].(error)
@@ -43,5 +44,6 @@ func (m *MuxClientWorkerFactory) Create() (*mux.ClientWorker, error) {
// Create indicates an expected call of Create
func (mr *MuxClientWorkerFactoryMockRecorder) Create() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MuxClientWorkerFactory)(nil).Create))
}

View File

@@ -36,6 +36,7 @@ func (m *OutboundManager) EXPECT() *OutboundManagerMockRecorder {
// AddHandler mocks base method
func (m *OutboundManager) AddHandler(arg0 context.Context, arg1 outbound.Handler) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddHandler", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
@@ -43,11 +44,13 @@ func (m *OutboundManager) AddHandler(arg0 context.Context, arg1 outbound.Handler
// AddHandler indicates an expected call of AddHandler
func (mr *OutboundManagerMockRecorder) AddHandler(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddHandler", reflect.TypeOf((*OutboundManager)(nil).AddHandler), arg0, arg1)
}
// Close mocks base method
func (m *OutboundManager) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
@@ -55,11 +58,13 @@ func (m *OutboundManager) Close() error {
// Close indicates an expected call of Close
func (mr *OutboundManagerMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*OutboundManager)(nil).Close))
}
// GetDefaultHandler mocks base method
func (m *OutboundManager) GetDefaultHandler() outbound.Handler {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDefaultHandler")
ret0, _ := ret[0].(outbound.Handler)
return ret0
@@ -67,11 +72,13 @@ func (m *OutboundManager) GetDefaultHandler() outbound.Handler {
// GetDefaultHandler indicates an expected call of GetDefaultHandler
func (mr *OutboundManagerMockRecorder) GetDefaultHandler() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultHandler", reflect.TypeOf((*OutboundManager)(nil).GetDefaultHandler))
}
// GetHandler mocks base method
func (m *OutboundManager) GetHandler(arg0 string) outbound.Handler {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetHandler", arg0)
ret0, _ := ret[0].(outbound.Handler)
return ret0
@@ -79,11 +86,13 @@ func (m *OutboundManager) GetHandler(arg0 string) outbound.Handler {
// GetHandler indicates an expected call of GetHandler
func (mr *OutboundManagerMockRecorder) GetHandler(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandler", reflect.TypeOf((*OutboundManager)(nil).GetHandler), arg0)
}
// RemoveHandler mocks base method
func (m *OutboundManager) RemoveHandler(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveHandler", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
@@ -91,11 +100,13 @@ func (m *OutboundManager) RemoveHandler(arg0 context.Context, arg1 string) error
// RemoveHandler indicates an expected call of RemoveHandler
func (mr *OutboundManagerMockRecorder) RemoveHandler(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveHandler", reflect.TypeOf((*OutboundManager)(nil).RemoveHandler), arg0, arg1)
}
// Start mocks base method
func (m *OutboundManager) Start() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Start")
ret0, _ := ret[0].(error)
return ret0
@@ -103,11 +114,13 @@ func (m *OutboundManager) Start() error {
// Start indicates an expected call of Start
func (mr *OutboundManagerMockRecorder) Start() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*OutboundManager)(nil).Start))
}
// Type mocks base method
func (m *OutboundManager) Type() interface{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Type")
ret0, _ := ret[0].(interface{})
return ret0
@@ -115,6 +128,7 @@ func (m *OutboundManager) Type() interface{} {
// Type indicates an expected call of Type
func (mr *OutboundManagerMockRecorder) Type() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", reflect.TypeOf((*OutboundManager)(nil).Type))
}
@@ -143,6 +157,7 @@ func (m *OutboundHandlerSelector) EXPECT() *OutboundHandlerSelectorMockRecorder
// Select mocks base method
func (m *OutboundHandlerSelector) Select(arg0 []string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Select", arg0)
ret0, _ := ret[0].([]string)
return ret0
@@ -150,5 +165,6 @@ func (m *OutboundHandlerSelector) Select(arg0 []string) []string {
// Select indicates an expected call of Select
func (mr *OutboundHandlerSelectorMockRecorder) Select(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Select", reflect.TypeOf((*OutboundHandlerSelector)(nil).Select), arg0)
}

View File

@@ -39,6 +39,7 @@ func (m *ProxyInbound) EXPECT() *ProxyInboundMockRecorder {
// Network mocks base method
func (m *ProxyInbound) Network() []net.Network {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Network")
ret0, _ := ret[0].([]net.Network)
return ret0
@@ -46,11 +47,13 @@ func (m *ProxyInbound) Network() []net.Network {
// Network indicates an expected call of Network
func (mr *ProxyInboundMockRecorder) Network() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Network", reflect.TypeOf((*ProxyInbound)(nil).Network))
}
// Process mocks base method
func (m *ProxyInbound) Process(arg0 context.Context, arg1 net.Network, arg2 internet.Connection, arg3 routing.Dispatcher) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Process", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
@@ -58,6 +61,7 @@ func (m *ProxyInbound) Process(arg0 context.Context, arg1 net.Network, arg2 inte
// Process indicates an expected call of Process
func (mr *ProxyInboundMockRecorder) Process(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Process", reflect.TypeOf((*ProxyInbound)(nil).Process), arg0, arg1, arg2, arg3)
}
@@ -86,6 +90,7 @@ func (m *ProxyOutbound) EXPECT() *ProxyOutboundMockRecorder {
// Process mocks base method
func (m *ProxyOutbound) Process(arg0 context.Context, arg1 *transport.Link, arg2 internet.Dialer) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Process", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
@@ -93,5 +98,6 @@ func (m *ProxyOutbound) Process(arg0 context.Context, arg1 *transport.Link, arg2
// Process indicates an expected call of Process
func (mr *ProxyOutboundMockRecorder) Process(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Process", reflect.TypeOf((*ProxyOutbound)(nil).Process), arg0, arg1, arg2)
}

View File

@@ -64,7 +64,7 @@ func (server *Server) handleConnection(conn net.Conn) {
}
pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
err := task.Run(task.Parallel(func() error {
err := task.Run(context.Background(), func() error {
defer pWriter.Close() // nolint: errcheck
for {
@@ -81,7 +81,7 @@ func (server *Server) handleConnection(conn net.Conn) {
}
}
}, func() error {
defer pReader.CloseError()
defer pReader.Interrupt()
w := buf.NewWriter(conn)
for {
@@ -96,7 +96,7 @@ func (server *Server) handleConnection(conn net.Conn) {
return err
}
}
}))()
})
if err != nil {
fmt.Println("failed to transfer data: ", err.Error())

View File

@@ -110,7 +110,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
return net.NewConnection(
net.ConnectionOutput(response.Body),
net.ConnectionInput(bwriter),
net.ConnectionOnClose(common.NewChainedClosable(breader, bwriter, response.Body)),
net.ConnectionOnClose(common.ChainedClosable{breader, bwriter, response.Body}),
), nil
}

View File

@@ -80,7 +80,7 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request)
conn := net.NewConnection(
net.ConnectionOutput(request.Body),
net.ConnectionInput(flushWriter{w: writer, d: done}),
net.ConnectionOnClose(common.NewChainedClosable(done, request.Body)),
net.ConnectionOnClose(common.ChainedClosable{done, request.Body}),
net.ConnectionLocalAddr(l.Addr()),
net.ConnectionRemoteAddr(remoteAddr),
)

View File

@@ -26,6 +26,9 @@ type KCPPacketReader struct {
func (r *KCPPacketReader) Read(b []byte) []Segment {
if r.Header != nil {
if int32(len(b)) <= r.Header.Size() {
return nil
}
b = b[r.Header.Size():]
}
if r.Security != nil {

View File

@@ -171,38 +171,10 @@ func (c *interConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
}
func (c *interConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
if mb.IsEmpty() {
return nil
}
if len(mb) == 1 {
_, err := c.Write(mb[0].Bytes())
buf.ReleaseMulti(mb)
return err
}
b := getBuffer()
defer putBuffer(b)
reader := buf.MultiBufferContainer{
MultiBuffer: mb,
}
defer reader.Close()
for {
nBytes, err := reader.Read(b[:1200])
if err != nil {
break
}
if nBytes == 0 {
continue
}
if _, err := c.Write(b[:nBytes]); err != nil {
return err
}
}
return nil
mb = buf.Compact(mb)
mb, err := buf.WriteMultiBuffer(c, mb)
buf.ReleaseMulti(mb)
return err
}
func (c *interConn) Write(b []byte) (int, error) {

View File

@@ -103,9 +103,9 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
quicConfig := &quic.Config{
ConnectionIDLength: 12,
HandshakeTimeout: time.Second * 8,
IdleTimeout: time.Second * 30,
MaxIncomingStreams: 128,
MaxIncomingUniStreams: 32,
IdleTimeout: time.Second * 45,
MaxIncomingStreams: 32,
MaxIncomingUniStreams: -1,
}
conn, err := wrapSysConn(rawConn, config)

View File

@@ -10,7 +10,7 @@ import (
)
var (
effectiveSystemDialer SystemDialer = DefaultSystemDialer{}
effectiveSystemDialer SystemDialer = &DefaultSystemDialer{}
)
type SystemDialer interface {
@@ -18,23 +18,32 @@ type SystemDialer interface {
}
type DefaultSystemDialer struct {
controllers []controller
}
func (DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: time.Second * 60,
DualStack: true,
}
if sockopt != nil {
if sockopt != nil || len(d.controllers) > 0 {
dialer.Control = func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
if err := applyOutboundSocketOptions(network, address, fd, sockopt); err != nil {
newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx))
if sockopt != nil {
if err := applyOutboundSocketOptions(network, address, fd, sockopt); err != nil {
newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
if dest.Network == net.Network_UDP && len(sockopt.BindAddress) > 0 && sockopt.BindPort > 0 {
if err := bindAddr(fd, sockopt.BindAddress, sockopt.BindPort); err != nil {
newError("failed to bind source address to ", sockopt.BindAddress).Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
}
if dest.Network == net.Network_UDP && len(sockopt.BindAddress) > 0 && sockopt.BindPort > 0 {
if err := bindAddr(fd, sockopt.BindAddress, sockopt.BindPort); err != nil {
newError("failed to bind source address to ", sockopt.BindAddress).Base(err).WriteToLog(session.ExportIDToError(ctx))
for _, ctl := range d.controllers {
if err := ctl(network, address, fd); err != nil {
newError("failed to apply external controller").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
})
@@ -83,7 +92,26 @@ func (v *SimpleSystemDialer) Dial(ctx context.Context, src net.Address, dest net
// v2ray:api:stable
func UseAlternativeSystemDialer(dialer SystemDialer) {
if dialer == nil {
effectiveSystemDialer = DefaultSystemDialer{}
effectiveSystemDialer = &DefaultSystemDialer{}
}
effectiveSystemDialer = dialer
}
// RegisterDialerController adds a controller to the effective system dialer.
// The controller can be used to operate on file descriptors before they are put into use.
// It only works when effective dialer is the default dialer.
//
// v2ray:api:beta
func RegisterDialerController(ctl func(network, address string, fd uintptr) error) error {
if ctl == nil {
return newError("nil listener controller")
}
dialer, ok := effectiveSystemDialer.(*DefaultSystemDialer)
if !ok {
return newError("RegisterListenerController not supported in custom dialer")
}
dialer.controllers = append(dialer.controllers, ctl)
return nil
}

View File

@@ -15,18 +15,13 @@ var (
type conn struct {
*tls.Conn
mergingWriter *buf.BufferedWriter
}
func (c *conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.mergingWriter == nil {
c.mergingWriter = buf.NewBufferedWriter(buf.NewWriter(c.Conn))
}
if err := c.mergingWriter.WriteMultiBuffer(mb); err != nil {
return err
}
return c.mergingWriter.Flush()
mb = buf.Compact(mb)
mb, err := buf.WriteMultiBuffer(c, mb)
buf.ReleaseMulti(mb)
return err
}
func (c *conn) HandshakeAddress() net.Address {

View File

@@ -18,10 +18,9 @@ var (
// connection is a wrapper for net.Conn over WebSocket connection.
type connection struct {
conn *websocket.Conn
reader io.Reader
mergingWriter *buf.BufferedWriter
remoteAddr net.Addr
conn *websocket.Conn
reader io.Reader
remoteAddr net.Addr
}
func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection {
@@ -70,13 +69,10 @@ func (c *connection) Write(b []byte) (int, error) {
}
func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if c.mergingWriter == nil {
c.mergingWriter = buf.NewBufferedWriter(&buf.BufferToBytesWriter{Writer: c})
}
if err := c.mergingWriter.WriteMultiBuffer(mb); err != nil {
return err
}
return c.mergingWriter.Flush()
mb = buf.Compact(mb)
mb, err := buf.WriteMultiBuffer(c, mb)
buf.ReleaseMulti(mb)
return err
}
func (c *connection) Close() error {

View File

@@ -182,7 +182,8 @@ func (p *pipe) Close() error {
return nil
}
func (p *pipe) CloseError() {
// Interrupt implements common.Interruptible.
func (p *pipe) Interrupt() {
p.Lock()
defer p.Unlock()

View File

@@ -67,14 +67,3 @@ func New(opts ...Option) (*Reader, *Writer) {
pipe: p,
}
}
type closeError interface {
CloseError()
}
// CloseError invokes CloseError() method if the object is either Reader or Writer.
func CloseError(v interface{}) {
if c, ok := v.(closeError); ok {
c.CloseError()
}
}

View File

@@ -1,6 +1,7 @@
package pipe_test
import (
"context"
"io"
"sync"
"testing"
@@ -31,7 +32,7 @@ func TestPipeReadWrite(t *testing.T) {
assert(rb.String(), Equals, "abcdefg")
}
func TestPipeCloseError(t *testing.T) {
func TestPipeInterrupt(t *testing.T) {
assert := With(t)
pReader, pWriter := New(WithSizeLimit(1024))
@@ -39,7 +40,7 @@ func TestPipeCloseError(t *testing.T) {
b := buf.New()
b.Write(payload)
assert(pWriter.WriteMultiBuffer(buf.MultiBuffer{b}), IsNil)
pWriter.CloseError()
pWriter.Interrupt()
rb, err := pReader.ReadMultiBuffer()
assert(err, Equals, io.ErrClosedPipe)
@@ -73,7 +74,7 @@ func TestPipeLimitZero(t *testing.T) {
bb.Write([]byte{'a', 'b'})
assert(pWriter.WriteMultiBuffer(buf.MultiBuffer{bb}), IsNil)
err := task.Run(task.Parallel(func() error {
err := task.Run(context.Background(), func() error {
b := buf.New()
b.Write([]byte{'c', 'd'})
return pWriter.WriteMultiBuffer(buf.MultiBuffer{b})
@@ -91,7 +92,7 @@ func TestPipeLimitZero(t *testing.T) {
time.Sleep(time.Second * 2)
pWriter.Close()
return nil
}))()
})
assert(err, IsNil)
}

View File

@@ -21,7 +21,7 @@ func (r *Reader) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error
return r.pipe.ReadMultiBufferTimeout(d)
}
// CloseError sets the pipe to error state. Both reading and writing from/to the pipe will return io.ErrClosedPipe.
func (r *Reader) CloseError() {
r.pipe.CloseError()
// Interrupt implements common.Interruptible.
func (r *Reader) Interrupt() {
r.pipe.Interrupt()
}

View File

@@ -19,7 +19,7 @@ func (w *Writer) Close() error {
return w.pipe.Close()
}
// CloseError sets the pipe to error state. Both reading and writing from/to the pipe will return io.ErrClosedPipe.
func (w *Writer) CloseError() {
w.pipe.CloseError()
// Interrupt implements common.Interruptible.
func (w *Writer) Interrupt() {
w.pipe.Interrupt()
}

View File

@@ -8,7 +8,6 @@ import (
"v2ray.com/core/app/dispatcher"
"v2ray.com/core/app/proxyman"
"v2ray.com/core/common"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
@@ -19,6 +18,7 @@ import (
"v2ray.com/core/proxy/dokodemo"
"v2ray.com/core/proxy/vmess"
"v2ray.com/core/proxy/vmess/outbound"
"v2ray.com/core/testing/servers/tcp"
)
func TestV2RayDependency(t *testing.T) {
@@ -36,7 +36,8 @@ func TestV2RayDependency(t *testing.T) {
}
func TestV2RayClose(t *testing.T) {
port := net.Port(dice.RollUint16())
port := tcp.PickPort()
userId := uuid.New()
config := &Config{
App: []*serial.TypedMessage{

View File

@@ -10,9 +10,6 @@ environment:
- GOARCH: 386
- GOARCH: amd64
hosts:
quic.clemente.io: 127.0.0.1
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install:

View File

@@ -3,24 +3,52 @@ package quic
import (
"sync"
"v2ray.com/core/common/bytespool"
"github.com/lucas-clemente/quic-go/internal/protocol"
"v2ray.com/core/common/bytespool"
)
type packetBuffer struct {
Slice []byte
// refCount counts how many packets the Slice is used in.
// It doesn't support concurrent use.
// It is > 1 when used for coalesced packet.
refCount int
}
// Split increases the refCount.
// It must be called when a packet buffer is used for more than one packet,
// e.g. when splitting coalesced packets.
func (b *packetBuffer) Split() {
b.refCount++
}
// Release decreases the refCount.
// It should be called when processing the packet is finished.
// When the refCount reaches 0, the packet buffer is put back into the pool.
func (b *packetBuffer) Release() {
if cap(b.Slice) < 2048 {
return
}
b.refCount--
if b.refCount < 0 {
panic("negative packetBuffer refCount")
}
// only put the packetBuffer back if it's not used any more
if b.refCount == 0 {
buffer := b.Slice[0:cap(b.Slice)]
bufferPool.Put(buffer)
}
}
var bufferPool *sync.Pool
func getPacketBuffer() *[]byte {
b := bufferPool.Get().([]byte)
return &b
}
func putPacketBuffer(buf *[]byte) {
b := *buf
if cap(b) < 2048 {
return
func getPacketBuffer() *packetBuffer {
buffer := bufferPool.Get().([]byte)
return &packetBuffer{
refCount: 1,
Slice: buffer[:protocol.MaxReceivePacketSize],
}
bufferPool.Put(b[:cap(b)])
}
func init() {

View File

@@ -3,7 +3,6 @@ package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
@@ -38,6 +37,8 @@ type client struct {
destConnID protocol.ConnectionID
origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
initialPacketNumber protocol.PacketNumber
initialVersion protocol.VersionNumber
version protocol.VersionNumber
@@ -54,8 +55,6 @@ var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
)
// DialAddr establishes a new QUIC connection to a server.
@@ -255,7 +254,7 @@ func (c *client) dial(ctx context.Context) error {
return err
}
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
if err == errCloseForRecreating {
return c.dial(ctx)
}
return err
@@ -263,8 +262,7 @@ func (c *client) dial(ctx context.Context) error {
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
// - errCloseSessionRecreating when the server sends a version negotiation packet, or a stateless retry is performed
// - any other error that might occur
// - when the connection is forward-secure
func (c *client) establishSecureConnection(ctx context.Context) error {
@@ -272,7 +270,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
if err != errCloseForRecreating && c.createdPacketConn {
c.conn.Close()
}
errorChan <- err
@@ -344,7 +342,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
c.version = newVersion
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
c.initialPacketNumber = c.session.closeForRecreating()
}
func (c *client) handleRetryPacket(hdr *wire.Header) {
@@ -370,7 +368,7 @@ func (c *client) handleRetryPacket(hdr *wire.Header) {
c.origDestConnID = c.destConnID
c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token
c.session.destroy(errCloseSessionForRetry)
c.initialPacketNumber = c.session.closeForRecreating()
}
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
@@ -401,6 +399,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
c.srcConnID,
c.config,
c.tlsConf,
c.initialPacketNumber,
params,
c.initialVersion,
c.logger,

View File

@@ -34,7 +34,7 @@ type sentPacketHandler struct {
packetNumberGenerator *packetNumberGenerator
lastSentRetransmittablePacketTime time.Time
lastSentHandshakePacketTime time.Time
lastSentCryptoPacketTime time.Time
nextPacketSendTime time.Time
@@ -56,8 +56,8 @@ type sentPacketHandler struct {
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times the crypto packets have been retransmitted without receiving an ack.
cryptoCount uint32
// The number of times a TLP has been sent without receiving an ack.
tlpCount uint32
@@ -78,7 +78,11 @@ type sentPacketHandler struct {
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
func NewSentPacketHandler(
initialPacketNumber protocol.PacketNumber,
rttStats *congestion.RTTStats,
logger utils.Logger,
) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
@@ -88,7 +92,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se
)
return &sentPacketHandler{
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
packetHistory: newSentPacketHistory(),
rttStats: rttStats,
congestion: congestion,
@@ -104,21 +108,21 @@ func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
h.logger.Debugf("Handshake complete. Discarding all outstanding crypto packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.Encryption1RTT {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
var cryptoPackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
cryptoPackets = append(cryptoPackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
for _, p := range cryptoPackets {
h.packetHistory.Remove(p.PacketNumber)
}
h.retransmissionQueue = queue
@@ -144,8 +148,10 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra
}
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %#x", p)
if h.logger.Debug() && h.lastSentPacketNumber != 0 {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %#x", p)
}
}
h.lastSentPacketNumber = packet.PacketNumber
@@ -161,7 +167,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
if isRetransmittable {
if packet.EncryptionLevel != protocol.Encryption1RTT {
h.lastSentHandshakePacketTime = packet.SendTime
h.lastSentCryptoPacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
packet.includedInBytesInFlight = true
@@ -185,7 +191,7 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
}
// duplicate or out of order ACK
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
if withPacketNumber != 0 && withPacketNumber < h.largestReceivedPacketWithAck {
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
return nil
}
@@ -299,8 +305,8 @@ func (h *sentPacketHandler) updateLossDetectionAlarm() {
return
}
if h.packetHistory.HasOutstandingHandshakePackets() {
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
if h.packetHistory.HasOutstandingCryptoPackets() {
h.alarm = h.lastSentCryptoPacketTime.Add(h.computeCryptoTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
@@ -381,12 +387,12 @@ func (h *sentPacketHandler) OnAlarm() error {
func (h *sentPacketHandler) onVerifiedAlarm() error {
var err error
if h.packetHistory.HasOutstandingHandshakePackets() {
if h.packetHistory.HasOutstandingCryptoPackets() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
h.logger.Debugf("Loss detection alarm fired in crypto mode. Crypto count: %d", h.cryptoCount)
}
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
h.cryptoCount++
err = h.queueCryptoPacketsForRetransmission()
} else if !h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
@@ -456,7 +462,7 @@ func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
}
h.rtoCount = 0
h.tlpCount = 0
h.handshakeCount = 0
h.cryptoCount = 0
return h.packetHistory.Remove(p.PacketNumber)
}
@@ -575,16 +581,16 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
func (h *sentPacketHandler) queueCryptoPacketsForRetransmission() error {
var cryptoPackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted && p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
cryptoPackets = append(cryptoPackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
for _, p := range cryptoPackets {
h.logger.Debugf("Queueing packet %#x as a crypto retransmission", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
@@ -603,11 +609,11 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
return nil
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
func (h *sentPacketHandler) computeCryptoTimeout() time.Duration {
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
// There's an implicit limit to this set by the crypto timeout.
return duration << h.cryptoCount
}
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {

View File

@@ -10,8 +10,8 @@ type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
numOutstandingPackets int
numOutstandingHandshakePackets int
numOutstandingPackets int
numOutstandingCryptoPackets int
firstOutstanding *PacketElement
}
@@ -36,7 +36,7 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
if p.canBeRetransmitted {
h.numOutstandingPackets++
if p.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets++
h.numOutstandingCryptoPackets++
}
}
return el
@@ -107,8 +107,8 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
h.numOutstandingCryptoPackets--
if h.numOutstandingCryptoPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
@@ -148,8 +148,8 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
h.numOutstandingCryptoPackets--
if h.numOutstandingCryptoPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
@@ -163,6 +163,6 @@ func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstandingPackets > 0
}
func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool {
return h.numOutstandingHandshakePackets > 0
func (h *sentPacketHistory) HasOutstandingCryptoPackets() bool {
return h.numOutstandingCryptoPackets > 0
}

View File

@@ -8,26 +8,56 @@ import (
)
type sealer struct {
iv []byte
aead cipher.AEAD
iv []byte
aead cipher.AEAD
pnEncrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
pnMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
}
var _ Sealer = &sealer{}
func newSealer(aead cipher.AEAD, iv []byte) Sealer {
func newSealer(aead cipher.AEAD, iv []byte, pnEncrypter cipher.Block, is1RTT bool) Sealer {
return &sealer{
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnEncrypter: pnEncrypter,
pnMask: make([]byte, pnEncrypter.BlockSize()),
}
}
func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
return s.aead.Seal(dst, s.nonceBuf, src, ad)
for i := 0; i < len(s.nonceBuf); i++ {
s.nonceBuf[i] ^= s.iv[i]
}
sealed := s.aead.Seal(dst, s.nonceBuf, src, ad)
for i := 0; i < len(s.nonceBuf); i++ {
s.nonceBuf[i] = 0
}
return sealed
}
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != s.pnEncrypter.BlockSize() {
panic("invalid sample size")
}
s.pnEncrypter.Encrypt(s.pnMask, sample)
if s.is1RTT {
*firstByte ^= s.pnMask[0] & 0x1f
} else {
*firstByte ^= s.pnMask[0] & 0xf
}
for i := range pnBytes {
pnBytes[i] ^= s.pnMask[i+1]
}
}
func (s *sealer) Overhead() int {
@@ -35,24 +65,54 @@ func (s *sealer) Overhead() int {
}
type opener struct {
iv []byte
aead cipher.AEAD
iv []byte
aead cipher.AEAD
pnDecrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
pnMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
}
var _ Opener = &opener{}
func newOpener(aead cipher.AEAD, iv []byte) Opener {
func newOpener(aead cipher.AEAD, iv []byte, pnDecrypter cipher.Block, is1RTT bool) Opener {
return &opener{
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
iv: iv,
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnDecrypter: pnDecrypter,
pnMask: make([]byte, pnDecrypter.BlockSize()),
}
}
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
return o.aead.Open(dst, o.nonceBuf, src, ad)
for i := 0; i < len(o.nonceBuf); i++ {
o.nonceBuf[i] ^= o.iv[i]
}
opened, err := o.aead.Open(dst, o.nonceBuf, src, ad)
for i := 0; i < len(o.nonceBuf); i++ {
o.nonceBuf[i] = 0
}
return opened, err
}
func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() {
panic("invalid sample size")
}
o.pnDecrypter.Encrypt(o.pnMask, sample)
if o.is1RTT {
*firstByte ^= o.pnMask[0] & 0x1f
} else {
*firstByte ^= o.pnMask[0] & 0xf
}
for i := range pnBytes {
pnBytes[i] ^= o.pnMask[i+1]
}
}

View File

@@ -1,12 +1,12 @@
package handshake
import (
"crypto/aes"
"crypto/tls"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
@@ -46,6 +46,11 @@ func (m messageType) String() string {
}
}
// ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level,
// but the corresponding opener has not yet been initialized
// This can happen when packets arrive out of order.
var ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available")
type cryptoSetup struct {
tlsConf *qtls.Config
@@ -74,7 +79,8 @@ type cryptoSetup struct {
clientHelloWrittenChan chan struct{}
initialStream io.Writer
initialAEAD crypto.AEAD
initialOpener Opener
initialSealer Sealer
handshakeStream io.Writer
handshakeOpener Opener
@@ -175,13 +181,14 @@ func newCryptoSetup(
logger utils.Logger,
perspective protocol.Perspective,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
initialAEAD, err := crypto.NewNullAEAD(connID, perspective)
initialSealer, initialOpener, err := newInitialAEAD(connID, perspective)
if err != nil {
return nil, nil, err
}
cs := &cryptoSetup{
initialStream: initialStream,
initialAEAD: initialAEAD,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
@@ -403,9 +410,19 @@ func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
}
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
opener := newOpener(suite.AEAD(key, iv), iv)
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "key", suite.KeyLen())
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "iv", suite.IVLen())
pnKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "pn", suite.KeyLen())
pnDecrypter, err := aes.NewCipher(pnKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
opener := newOpener(
suite.AEAD(key, iv),
iv,
pnDecrypter,
h.readEncLevel == protocol.Encryption1RTT,
)
switch h.readEncLevel {
case protocol.EncryptionInitial:
@@ -423,9 +440,19 @@ func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte)
}
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
sealer := newSealer(suite.AEAD(key, iv), iv)
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "key", suite.KeyLen())
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "iv", suite.IVLen())
pnKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "pn", suite.KeyLen())
pnEncrypter, err := aes.NewCipher(pnKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
sealer := newSealer(
suite.AEAD(key, iv),
iv,
pnEncrypter,
h.writeEncLevel == protocol.Encryption1RTT,
)
switch h.writeEncLevel {
case protocol.EncryptionInitial:
@@ -467,7 +494,7 @@ func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
if h.handshakeSealer != nil {
return protocol.EncryptionHandshake, h.handshakeSealer
}
return protocol.EncryptionInitial, h.initialAEAD
return protocol.EncryptionInitial, h.initialSealer
}
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
@@ -475,7 +502,7 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
switch level {
case protocol.EncryptionInitial:
return h.initialAEAD, nil
return h.initialSealer, nil
case protocol.EncryptionHandshake:
if h.handshakeSealer == nil {
return nil, errNoSealer
@@ -491,22 +518,23 @@ func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLeve
}
}
func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
return h.initialAEAD.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
if h.handshakeOpener == nil {
return nil, errors.New("no handshake opener")
func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) {
switch level {
case protocol.EncryptionInitial:
return h.initialOpener, nil
case protocol.EncryptionHandshake:
if h.handshakeOpener == nil {
return nil, ErrOpenerNotYetAvailable
}
return h.handshakeOpener, nil
case protocol.Encryption1RTT:
if h.opener == nil {
return nil, ErrOpenerNotYetAvailable
}
return h.opener, nil
default:
return nil, fmt.Errorf("CryptoSetup: no opener with encryption level %s", level)
}
return h.handshakeOpener.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
if h.opener == nil {
return nil, errors.New("no 1-RTT opener")
}
return h.opener.Open(dst, src, pn, ad)
}
func (h *cryptoSetup) ConnectionState() ConnectionState {

View File

@@ -0,0 +1,66 @@
package handshake
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
)
var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0}
func newInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Sealer, Opener, error) {
clientSecret, serverSecret := computeSecrets(connID)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
mySecret = clientSecret
otherSecret = serverSecret
} else {
mySecret = serverSecret
otherSecret = clientSecret
}
myKey, myPNKey, myIV := computeInitialKeyAndIV(mySecret)
otherKey, otherPNKey, otherIV := computeInitialKeyAndIV(otherSecret)
encrypterCipher, err := aes.NewCipher(myKey)
if err != nil {
return nil, nil, err
}
encrypter, err := cipher.NewGCM(encrypterCipher)
if err != nil {
return nil, nil, err
}
pnEncrypter, err := aes.NewCipher(myPNKey)
if err != nil {
return nil, nil, err
}
decrypterCipher, err := aes.NewCipher(otherKey)
if err != nil {
return nil, nil, err
}
decrypter, err := cipher.NewGCM(decrypterCipher)
if err != nil {
return nil, nil, err
}
pnDecrypter, err := aes.NewCipher(otherPNKey)
if err != nil {
return nil, nil, err
}
return newSealer(encrypter, myIV, pnEncrypter, false), newOpener(decrypter, otherIV, pnDecrypter, false), nil
}
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
initialSecret := qtls.HkdfExtract(crypto.SHA256, connID, quicVersion1Salt)
clientSecret = qtls.HkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
serverSecret = qtls.HkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
return
}
func computeInitialKeyAndIV(secret []byte) (key, pnKey, iv []byte) {
key = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
pnKey = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic hp", 16)
iv = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
return
}

View File

@@ -11,11 +11,13 @@ import (
// Opener opens a packet
type Opener interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// Sealer seals a packet
type Sealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
Overhead() int
}
@@ -35,10 +37,7 @@ type CryptoSetup interface {
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
GetOpener(protocol.EncryptionLevel) (Opener, error)
}
// ConnectionState records basic details about the QUIC connection.

View File

@@ -59,6 +59,19 @@ func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
}
// GetOpener mocks base method
func (m *MockCryptoSetup) GetOpener(arg0 protocol.EncryptionLevel) (handshake.Opener, error) {
ret := m.ctrl.Call(m, "GetOpener", arg0)
ret0, _ := ret[0].(handshake.Opener)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOpener indicates an expected call of GetOpener
func (mr *MockCryptoSetupMockRecorder) GetOpener(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetOpener), arg0)
}
// GetSealer mocks base method
func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
ret := m.ctrl.Call(m, "GetSealer")
@@ -97,45 +110,6 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
}
// Open1RTT mocks base method
func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open1RTT indicates an expected call of Open1RTT
func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3)
}
// OpenHandshake mocks base method
func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenHandshake indicates an expected call of OpenHandshake
func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
}
// OpenInitial mocks base method
func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenInitial indicates an expected call of OpenInitial
func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3)
}
// RunHandshake mocks base method
func (m *MockCryptoSetup) RunHandshake() error {
ret := m.ctrl.Call(m, "RunHandshake")

View File

@@ -1,10 +1,10 @@
package mocks
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
//go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener"
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
//go:generate sh -c "../mockgen_internal.sh mocks congestion.go github.com/lucas-clemente/quic-go/internal/congestion SendAlgorithm"
//go:generate sh -c "../mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
//go:generate sh -c "../mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"

View File

@@ -0,0 +1,58 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Opener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockOpener is a mock of Opener interface
type MockOpener struct {
ctrl *gomock.Controller
recorder *MockOpenerMockRecorder
}
// MockOpenerMockRecorder is the mock recorder for MockOpener
type MockOpenerMockRecorder struct {
mock *MockOpener
}
// NewMockOpener creates a new mock instance
func NewMockOpener(ctrl *gomock.Controller) *MockOpener {
mock := &MockOpener{ctrl: ctrl}
mock.recorder = &MockOpenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockOpener) EXPECT() *MockOpenerMockRecorder {
return m.recorder
}
// DecryptHeader mocks base method
func (m *MockOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.Call(m, "DecryptHeader", arg0, arg1, arg2)
}
// DecryptHeader indicates an expected call of DecryptHeader
func (mr *MockOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockOpener)(nil).DecryptHeader), arg0, arg1, arg2)
}
// Open mocks base method
func (m *MockOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "Open", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open
func (mr *MockOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockOpener)(nil).Open), arg0, arg1, arg2, arg3)
}

View File

@@ -34,6 +34,16 @@ func (m *MockSealer) EXPECT() *MockSealerMockRecorder {
return m.recorder
}
// EncryptHeader mocks base method
func (m *MockSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) {
m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2)
}
// EncryptHeader indicates an expected call of EncryptHeader
func (mr *MockSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockSealer)(nil).EncryptHeader), arg0, arg1, arg2)
}
// Overhead mocks base method
func (m *MockSealer) Overhead() int {
ret := m.ctrl.Call(m, "Overhead")

View File

@@ -16,8 +16,8 @@ const (
PacketNumberLen4 PacketNumberLen = 4
)
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func InferPacketNumber(
// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func DecodePacketNumber(
packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber,

View File

@@ -0,0 +1,18 @@
-----BEGIN CERTIFICATE-----
MIIC0DCCAbgCCQCmiwJpSoekpDANBgkqhkiG9w0BAQsFADAqMRMwEQYDVQQKDApx
dWljLWdvIENBMRMwEQYDVQQLDApxdWljLWdvIENBMB4XDTE4MTIwODA2NDIyMVoX
DTI4MTIwNTA2NDIyMVowKjETMBEGA1UECgwKcXVpYy1nbyBDQTETMBEGA1UECwwK
cXVpYy1nbyBDQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAN5MxI09
i01xRON732BFIuxO2SGjA9jYkvUvNXK886gifp2BfWLcOW1DHkXxBnhWMqfpcIWM
GviF4G2Mp0HEJDMe+4LBxje/1e2WA+nzQlIZD6LaDi98nXJaAcCMM4a64Vm0i8Z3
+4c+O93+5TekPn507nl7QA1IaEEtoek7w7wDw4ZF3ET+nns2HwVpV/ugfuYOQbTJ
8Np+zO8EfPMTUjEpKdl4bp/yqcouWD+oIhoxmx1V+LxshcpSwtzHIAi6gjHUDCEe
bk5Y2GBT4VR5WKmNGvlfe9L0Gn0ZLJoeXDshrunF0xEmSv8MxlHcKH/u4IHiO+6x
+5sdslqY7uEPEhkCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAhvXUMiatkgsnoRHc
UobKraGttETivxvtKpc48o1TSkR+kCKbMnygmrvc5niEqc9iDg8JI6HjBKJ3/hfA
uKdyiR8cQNcQRgJ/3FVx0n3KGDUbHJSuIQzFvXom2ZPdlAHFqAT+8AVrz42v8gct
gyiGdFCSNisDbevOiRHuJtZ0m8YsGgtfU48wqGOaSSsRz4mYD6kqBFd0+Ja3/EGv
vl24L5xMCy1zGGl6wKPa7TT7ok4TfD1YmIXOfmWYop6cTLwePLj1nHrLi0AlsSn1
2pFlosc9/qEbO5drqNoxUZfeF0L9RUSuArHRSO779dW/AmOtFdK3yaBGqflg0r7p
lYombA==
-----END CERTIFICATE-----

View File

@@ -2,6 +2,9 @@ package testdata
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"io/ioutil"
"path"
"runtime"
)
@@ -14,13 +17,12 @@ func init() {
panic("Failed to get current frame")
}
certPath = path.Join(path.Dir(path.Dir(path.Dir(filename))), "example")
certPath = path.Dir(filename)
}
// GetCertificatePaths returns the paths to 'fullchain.pem' and 'privkey.pem' for the
// quic.clemente.io cert.
// GetCertificatePaths returns the paths to certificate and key
func GetCertificatePaths() (string, string) {
return path.Join(certPath, "fullchain.pem"), path.Join(certPath, "privkey.pem")
return path.Join(certPath, "cert.pem"), path.Join(certPath, "priv.key")
}
// GetTLSConfig returns a tls config for quic.clemente.io
@@ -34,11 +36,22 @@ func GetTLSConfig() *tls.Config {
}
}
// GetCertificate returns a certificate for quic.clemente.io
func GetCertificate() tls.Certificate {
cert, err := tls.LoadX509KeyPair(GetCertificatePaths())
// GetRootCA returns an x509.CertPool containing the CA certificate
func GetRootCA() *x509.CertPool {
caCertPath := path.Join(certPath, "ca.pem")
caCertRaw, err := ioutil.ReadFile(caCertPath)
if err != nil {
panic(err)
}
return cert
p, _ := pem.Decode(caCertRaw)
if p.Type != "CERTIFICATE" {
panic("expected a certificate")
}
caCert, err := x509.ParseCertificate(p.Bytes)
if err != nil {
panic(err)
}
certPool := x509.NewCertPool()
certPool.AddCert(caCert)
return certPool
}

View File

@@ -0,0 +1,18 @@
-----BEGIN CERTIFICATE-----
MIIC3jCCAcYCCQCV4BOv+SRo4zANBgkqhkiG9w0BAQUFADAqMRMwEQYDVQQKDApx
dWljLWdvIENBMRMwEQYDVQQLDApxdWljLWdvIENBMB4XDTE4MTIwODA2NDMwMloX
DTI4MTIwNTA2NDMwMlowODEQMA4GA1UECgwHcXVpYy1nbzEQMA4GA1UECwwHcXVp
Yy1nbzESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEAyc/hS8XHkOJaLrdPOSTZFUBVyHNSfQUX/3dEpmccPlLQLgopYZZO
W/cVhkxAfQ3e68xKkuZKfZN5Hytn5V/AOSk281BqxFxpfCcKVYqVpDZH99+jaVfG
ImPp5Y22qCnbSEwYrMTcLiK8PVa4MkpKf1KNacVlqawU+ZWI5fevAFGTtmrMJ4S+
qZY7tAaVkax+OiKWWfhLQjJCsN3IIDysTfbWao6cYKgtTfqVChEddzS7LRJVRaB+
+huUbB87tRBJbCuJX65yB7Fw77YiKoFjc5r2845fcS2Ew4+w29mbXoj7M7g6eup5
SnCydsCvyNy6VkgaSlWS0DXvxuzWshwUrwIDAQABMA0GCSqGSIb3DQEBBQUAA4IB
AQBWgmFunf44X3/NIjNvVLeQsfGW+4L/lCi2F5tqa70Hkda+xhKACnQQGB2qCSCF
Jfxj4iKrFJ7+JB8GnribWthLuDq49PQrTI+1wKFd9c2b8DXzJLz4Onw+mPX97pZm
TflQSIxXRaFAIQuUWNTArZZEe1ESSlnaBuE5w77LMf4GMFD3P3jzSHKUyM1sF97j
gRbIt8Jw7Uyd8vlXk6m2wvO5H3hZrrhJUJH3WW13a7wLJRnff2meKU90hkLQwuxO
kyh0k/h158/r2ibiahTmQEgHs9vQaCM+HXuk5P+Tzq5Zl/n0dMFZMfkqNkD4nym/
nu7zfdwMlcBjKt9g3BGw+KE3
-----END CERTIFICATE-----

View File

@@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAyc/hS8XHkOJaLrdPOSTZFUBVyHNSfQUX/3dEpmccPlLQLgop
YZZOW/cVhkxAfQ3e68xKkuZKfZN5Hytn5V/AOSk281BqxFxpfCcKVYqVpDZH99+j
aVfGImPp5Y22qCnbSEwYrMTcLiK8PVa4MkpKf1KNacVlqawU+ZWI5fevAFGTtmrM
J4S+qZY7tAaVkax+OiKWWfhLQjJCsN3IIDysTfbWao6cYKgtTfqVChEddzS7LRJV
RaB++huUbB87tRBJbCuJX65yB7Fw77YiKoFjc5r2845fcS2Ew4+w29mbXoj7M7g6
eup5SnCydsCvyNy6VkgaSlWS0DXvxuzWshwUrwIDAQABAoIBADunQwVO1Qqync2p
SbWueqyZc8HotL1XwBw3eQdm+yZA/GBfiJPcBhWRF7+20mkkrHwuyuxZPjOYX/ki
r3dRslQzJpcNckHQvy1/rMJUUJ9VnDhc1sTQuTR5LC46kX9rv/HC7JhFKIBKrDHF
bHURGKxCDqLxQnfA8gJEfU7cw9HnxMxmKv7qJ3O7EHYMuTQstkYsGOr60zX/C+Zm
7YA+d7nx1LpL0m2lKs70iz5MzGg+KgKyrkMWQ30gpxILBxNzzuQr7Kv/+63/3+G9
nfCGeLmwGakPFpm6/GwiABE0yGa71YNAQs18iUTZwP/ZEDw3KB2SoG8wcqWjNAd+
cUF2PgECgYEA5Xe/OZouw9h0NBo0Zut+HC0YOuUfY72Ug9Fm8bAS6wDuPiO3jIvK
J40d+ZHNp4AakfTuugiqEDJRlV7T/F2K/KHDWvXTg5ZpAC8dsZKJMxyyAp8EniYQ
vsoFWeHBfsD83rCVKLcjDB3hbQH+MSoT3lsqjZRNiNUMK13gyuX7k28CgYEA4SWF
ySRXUqUezX5D8kV5rQVYLcw6WVB3czYd7cKf8zHy4xJX0ZicyZjohknMmKCkdx+M
1mrxlqUO7EBGokM8vs87m/4rz6bjgZffpWzUmP/x1+3f3j/wIZeqNilW8NqY5nLi
tj3JxMwaesU86rOekSy27BlX4sjQ8NRs7Z2d8sECgYBKAD8kBWwVbqWy88x4cHOA
BK7ut1tTIB1YEVzgjobbULaERaJ46c/sx16mUHYBEZf///xI9Ghbxs52nFlC5qve
4xAMMoDey8/a5lbuIDKs0BE8NSoZEm+OB7qIDP0IspYZ/tprgfwEeVJshBsEoew8
Ziwn8m66tPIyvhizdk2WcwKBgH2M8RgDffaGQbESEk3N1FZZvpx7YKZhqtrCeNoX
SB7T4cAigHpPAk+hRzlref46xrvvChiftmztSm8QQNNHb15wLauFh2Taic/Ao2Sa
VcukHnbtHYPQX9Y7vx1I3ESfgdgwhKBfwF5P+wwvZRL0ax5FsxPh5hJ/LZS+wKeY
13WBAoGAXSqG3ANmCyvSLVmAXGIbr0Tuixf/a25sPrlq7Im1H1OnqLrcyxWCLV3E
6gprhG5An0Zlr/FFRxVojf0TKmtJZs9B70/6WPwVvFtBduCM1zuUuCQYU9opTJQL
ElMIP4VfjABm4tm1fqGIy1PQP0Osb6/qb2DPPJqsFiW0oRByyMA=
-----END RSA PRIVATE KEY-----

View File

@@ -1,6 +1,9 @@
package utils
import "time"
import (
"math"
"time"
)
// A Timer wrapper that behaves correctly when resetting
type Timer struct {
@@ -11,7 +14,7 @@ type Timer struct {
// NewTimer creates a new timer that is not set
func NewTimer() *Timer {
return &Timer{t: time.NewTimer(0)}
return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))}
}
// Chan returns the channel of the wrapped timer
@@ -31,7 +34,9 @@ func (t *Timer) Reset(deadline time.Time) {
if !t.t.Stop() && !t.read {
<-t.t.C
}
t.t.Reset(time.Until(deadline))
if !deadline.IsZero() {
t.t.Reset(time.Until(deadline))
}
t.read = false
t.deadline = deadline

View File

@@ -30,7 +30,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*Exte
if err != nil {
return nil, err
}
if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil {
if _, err := b.Seek(int64(h.ParsedLen())-1, io.SeekCurrent); err != nil {
return nil, err
}
if h.IsLongHeader {

View File

@@ -24,8 +24,8 @@ type Header struct {
SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
typeByte byte
len int // how many bytes were read while parsing this header
typeByte byte
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
// ParseHeader parses the header.
@@ -39,7 +39,7 @@ func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
if err != nil {
return nil, err
}
h.len = startLen - b.Len()
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return h, nil
}
@@ -171,6 +171,11 @@ func (h *Header) IsVersionNegotiation() bool {
return h.IsLongHeader && h.Version == 0
}
// ParsedLen returns the number of bytes that were consumed when parsing the header
func (h *Header) ParsedLen() protocol.ByteCount {
return h.parsedLen
}
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {

View File

@@ -13,7 +13,6 @@ package quic
//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager"
//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker"
//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD"
//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner"
//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession"
//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler"

View File

@@ -144,8 +144,8 @@ func (h *packetHandlerMap) close(e error) error {
func (h *packetHandlerMap) listen() {
for {
data := *getPacketBuffer()
data = data[:protocol.MaxReceivePacketSize]
buffer := getPacketBuffer()
data := buffer.Slice
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
n, addr, err := h.conn.ReadFrom(data)
@@ -153,55 +153,110 @@ func (h *packetHandlerMap) listen() {
h.close(err)
return
}
data = data[:n]
if err := h.handlePacket(addr, data); err != nil {
h.logger.Debugf("error handling packet from %s: %s", addr, err)
}
h.handlePacket(addr, buffer, data[:n])
}
}
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
r := bytes.NewReader(data)
hdr, err := wire.ParseHeader(r, h.connIDLen)
// drop the packet if we can't parse the header
func (h *packetHandlerMap) handlePacket(
addr net.Addr,
buffer *packetBuffer,
data []byte,
) {
packets, err := h.parsePacket(addr, buffer, data)
if err != nil {
return fmt.Errorf("error parsing header: %s", err)
h.logger.Debugf("error parsing packets from %s: %s", addr, err)
// This is just the error from parsing the last packet.
// We still need to process the packets that were successfully parsed before.
}
p := &receivedPacket{
remoteAddr: addr,
hdr: hdr,
data: data,
rcvTime: time.Now(),
if len(packets) == 0 {
buffer.Release()
return
}
h.handleParsedPackets(packets)
}
func (h *packetHandlerMap) parsePacket(
addr net.Addr,
buffer *packetBuffer,
data []byte,
) ([]*receivedPacket, error) {
rcvTime := time.Now()
packets := make([]*receivedPacket, 0, 1)
var counter int
var lastConnID protocol.ConnectionID
for len(data) > 0 {
if counter > 0 && h.logger.Debug() {
h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes", counter, len(packets[counter-1].data))
}
hdr, err := wire.ParseHeader(bytes.NewReader(data), h.connIDLen)
// drop the packet if we can't parse the header
if err != nil {
return packets, fmt.Errorf("error parsing header: %s", err)
}
if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) {
return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
}
lastConnID = hdr.DestConnectionID
var rest []byte
if hdr.IsLongHeader {
if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
return packets, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
packetLen := int(hdr.ParsedLen() + hdr.Length)
rest = data[packetLen:]
data = data[:packetLen]
}
if counter > 0 {
buffer.Split()
}
counter++
packets = append(packets, &receivedPacket{
remoteAddr: addr,
hdr: hdr,
rcvTime: rcvTime,
data: data,
buffer: buffer,
})
data = rest
}
return packets, nil
}
func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) {
h.mutex.RLock()
defer h.mutex.RUnlock()
handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
// coalesced packets all have the same destination connection ID
handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)]
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
return nil
}
// No session found.
// This might be a stateless reset.
if !hdr.IsLongHeader {
if len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
return nil
}
for _, p := range packets {
if handlerFound { // existing session
handlerEntry.handler.handlePacket(p)
continue
}
// TODO(#943): send a stateless reset
return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
// No session found.
// This might be a stateless reset.
if !p.hdr.IsLongHeader {
if len(p.data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], p.data[len(p.data)-16:])
if sess, ok := h.resetTokens[token]; ok {
sess.destroy(errors.New("received a stateless reset"))
continue
}
}
// TODO(#943): send a stateless reset
h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
break // a short header packet is always the last in a coalesced packet
}
if h.server != nil { // no server set
h.server.handlePacket(p)
}
h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID)
}
if h.server == nil { // no server set
return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
}
h.server.handlePacket(p)
return nil
}

View File

@@ -25,10 +25,25 @@ type packer interface {
}
type packedPacket struct {
header *wire.ExtendedHeader
raw []byte
frames []wire.Frame
encryptionLevel protocol.EncryptionLevel
header *wire.ExtendedHeader
raw []byte
frames []wire.Frame
buffer *packetBuffer
}
func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
if !p.header.IsLongHeader {
return protocol.Encryption1RTT
}
switch p.header.Type {
case protocol.PacketTypeInitial:
return protocol.EncryptionInitial
case protocol.PacketTypeHandshake:
return protocol.EncryptionHandshake
default:
return protocol.EncryptionUnspecified
}
}
func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
@@ -37,7 +52,7 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
PacketType: p.header.Type,
Frames: p.frames,
Length: protocol.ByteCount(len(p.raw)),
EncryptionLevel: p.encryptionLevel,
EncryptionLevel: p.EncryptionLevel(),
SendTime: time.Now(),
}
}
@@ -136,13 +151,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
frames := []wire.Frame{ccf}
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
raw, err := p.writeAndSealPacket(header, frames, sealer)
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
return p.writeAndSealPacket(header, frames, sealer)
}
func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
@@ -154,13 +163,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
encLevel, sealer := p.cryptoSetup.GetSealer()
header := p.getHeader(encLevel)
frames := []wire.Frame{ack}
raw, err := p.writeAndSealPacket(header, frames, sealer)
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
return p.writeAndSealPacket(header, frames, sealer)
}
// PackRetransmission packs a retransmission
@@ -227,16 +230,11 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
sf.DataLenPresent = false
}
raw, err := p.writeAndSealPacket(header, frames, sealer)
p, err := p.writeAndSealPacket(header, frames, sealer)
if err != nil {
return nil, err
}
packets = append(packets, &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
})
packets = append(packets, p)
}
return packets, nil
}
@@ -281,16 +279,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
p.numNonRetransmittableAcks = 0
}
raw, err := p.writeAndSealPacket(header, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
header: header,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
return p.writeAndSealPacket(header, frames, sealer)
}
func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
@@ -320,16 +309,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
}
cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
frames = append(frames, cf)
raw, err := p.writeAndSealPacket(hdr, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
header: hdr,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
return p.writeAndSealPacket(hdr, frames, sealer)
}
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) {
@@ -395,9 +375,9 @@ func (p *packetPacker) writeAndSealPacket(
header *wire.ExtendedHeader,
frames []wire.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := *getPacketBuffer()
buffer := bytes.NewBuffer(raw[:0])
) (*packedPacket, error) {
packetBuffer := getPacketBuffer()
buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
@@ -421,7 +401,7 @@ func (p *packetPacker) writeAndSealPacket(
if err := header.Write(buffer, p.version); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
payloadOffset := buffer.Len()
// write all frames but the last one
for _, frame := range frames[:len(frames)-1] {
@@ -436,7 +416,7 @@ func (p *packetPacker) writeAndSealPacket(
sf.DataLenPresent = true
}
} else {
payloadLen := buffer.Len() - payloadStartIndex + int(lastFrame.Length(p.version))
payloadLen := buffer.Len() - payloadOffset + int(lastFrame.Length(p.version))
if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 {
// Pad the packet such that packet number length + payload length is 4 bytes.
// This is needed to enable the peer to get a 16 byte sample for header protection.
@@ -458,15 +438,27 @@ func (p *packetPacker) writeAndSealPacket(
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
}
raw = raw[0:buffer.Len()]
_ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex])
raw := buffer.Bytes()
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
raw = raw[0 : buffer.Len()+sealer.Overhead()]
pnOffset := payloadOffset - int(header.PacketNumberLen)
sealer.EncryptHeader(
raw[pnOffset+4:pnOffset+4+16],
&raw[0],
raw[pnOffset:payloadOffset],
)
num := p.pnManager.PopPacketNumber()
if num != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
return raw, nil
return &packedPacket{
header: header,
raw: raw,
frames: frames,
buffer: packetBuffer,
}, nil
}
func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {

Some files were not shown because too many files have changed in this diff Show More