Compare commits

..

14 Commits

Author SHA1 Message Date
Darien Raymond
6b355ef461 fix a typo in local dns 2018-12-06 21:34:05 +01:00
Darien Raymond
f49f85b5bd fix #1462 2018-12-06 21:09:41 +01:00
Darien Raymond
21b3f66b8b fix IP parsing in local dns client 2018-12-06 20:11:45 +01:00
Darien Raymond
30b5bffad4 support custom log handler 2018-12-06 17:37:05 +01:00
Darien Raymond
b9450d8475 Revert "use default logger for android and ios"
This reverts commit 9743380e2d.
2018-12-06 17:03:15 +01:00
Darien Raymond
50e77cbb19 fix broken test 2018-12-06 14:44:24 +01:00
Darien Raymond
9743380e2d use default logger for android and ios 2018-12-06 14:40:45 +01:00
Darien Raymond
427679e66d simplify task execution 2018-12-06 11:35:02 +01:00
Darien Raymond
cf1705267e switch to errgroup 2018-12-06 10:22:14 +01:00
Darien Raymond
c89183e6b3 update port picking 2018-12-05 16:27:32 +01:00
Darien Raymond
4104a86b6c use default dns resolver to prevent errors in android 2018-12-05 15:48:40 +01:00
Darien Raymond
82d562d1f0 use session.Outbound.ResolvedIPs 2018-12-04 20:36:51 +01:00
Darien Raymond
98d89aebc2 fix release script 2018-12-04 19:40:12 +01:00
Darien Raymond
72f5e9de16 fix vendor directory 2018-12-04 18:59:14 +01:00
29 changed files with 402 additions and 300 deletions

View File

@@ -31,32 +31,24 @@ func New(ctx context.Context, config *Config) (*Instance, error) {
}
func (g *Instance) initAccessLogger() error {
switch g.config.AccessLogType {
case LogType_File:
creator, err := log.CreateFileLogWriter(g.config.AccessLogPath)
if err != nil {
return err
}
g.accessLogger = log.NewLogger(creator)
case LogType_Console:
g.accessLogger = log.NewLogger(log.CreateStdoutLogWriter())
default:
handler, err := createHandler(g.config.AccessLogType, HandlerCreatorOptions{
Path: g.config.AccessLogPath,
})
if err != nil {
return err
}
g.accessLogger = handler
return nil
}
func (g *Instance) initErrorLogger() error {
switch g.config.ErrorLogType {
case LogType_File:
creator, err := log.CreateFileLogWriter(g.config.ErrorLogPath)
if err != nil {
return err
}
g.errorLogger = log.NewLogger(creator)
case LogType_Console:
g.errorLogger = log.NewLogger(log.CreateStdoutLogWriter())
default:
handler, err := createHandler(g.config.ErrorLogType, HandlerCreatorOptions{
Path: g.config.ErrorLogPath,
})
if err != nil {
return err
}
g.errorLogger = handler
return nil
}

51
app/log/log_creator.go Normal file
View File

@@ -0,0 +1,51 @@
package log
import (
"v2ray.com/core/common"
"v2ray.com/core/common/log"
)
type HandlerCreatorOptions struct {
Path string
}
type HandlerCreator func(LogType, HandlerCreatorOptions) (log.Handler, error)
var (
handlerCreatorMap = make(map[LogType]HandlerCreator)
)
func RegisterHandlerCreator(logType LogType, f HandlerCreator) error {
if f == nil {
return newError("nil HandlerCreator")
}
handlerCreatorMap[logType] = f
return nil
}
func createHandler(logType LogType, options HandlerCreatorOptions) (log.Handler, error) {
creator, found := handlerCreatorMap[logType]
if !found {
return nil, newError("unable to create log handler for ", logType)
}
return creator(logType, options)
}
func init() {
common.Must(RegisterHandlerCreator(LogType_Console, func(lt LogType, options HandlerCreatorOptions) (log.Handler, error) {
return log.NewLogger(log.CreateStdoutLogWriter()), nil
}))
common.Must(RegisterHandlerCreator(LogType_File, func(lt LogType, options HandlerCreatorOptions) (log.Handler, error) {
creator, err := log.CreateFileLogWriter(options.Path)
if err != nil {
return nil, err
}
return log.NewLogger(creator), nil
}))
common.Must(RegisterHandlerCreator(LogType_None, func(lt LogType, options HandlerCreatorOptions) (log.Handler, error) {
return nil, nil
}))
}

52
app/log/log_test.go Normal file
View File

@@ -0,0 +1,52 @@
package log_test
import (
"context"
"testing"
"github.com/golang/mock/gomock"
"v2ray.com/core/app/log"
"v2ray.com/core/common"
clog "v2ray.com/core/common/log"
"v2ray.com/core/testing/mocks"
)
func TestCustomLogHandler(t *testing.T) {
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
var loggedValue []string
mockHandler := mocks.NewLogHandler(mockCtl)
mockHandler.EXPECT().Handle(gomock.Any()).AnyTimes().DoAndReturn(func(msg clog.Message) {
loggedValue = append(loggedValue, msg.String())
})
log.RegisterHandlerCreator(log.LogType_Console, func(lt log.LogType, options log.HandlerCreatorOptions) (clog.Handler, error) {
return mockHandler, nil
})
logger, err := log.New(context.Background(), &log.Config{
ErrorLogLevel: clog.Severity_Debug,
ErrorLogType: log.LogType_Console,
AccessLogType: log.LogType_None,
})
common.Must(err)
common.Must(logger.Start())
clog.Record(&clog.GeneralMessage{
Severity: clog.Severity_Debug,
Content: "test",
})
if len(loggedValue) < 2 {
t.Fatal("expected 2 log messages, but actually ", loggedValue)
}
if loggedValue[1] != "[Debug] test" {
t.Fatal("expected '[Debug] test', but actually ", loggedValue[1])
}
common.Must(logger.Close())
}

View File

@@ -111,9 +111,18 @@ func targetFromContent(ctx context.Context) net.Destination {
return outbound.Target
}
func resolvedIPFromContext(ctx context.Context) []net.IP {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
return nil
}
return outbound.ResolvedIPs
}
type MultiGeoIPMatcher struct {
matchers []*GeoIPMatcher
destFunc func(context.Context) net.Destination
matchers []*GeoIPMatcher
destFunc func(context.Context) net.Destination
resolvedIPFunc func(context.Context) []net.IP
}
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
@@ -126,17 +135,18 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
matchers = append(matchers, matcher)
}
var destFunc func(context.Context) net.Destination
if onSource {
destFunc = sourceFromContext
} else {
destFunc = targetFromContent
matcher := &MultiGeoIPMatcher{
matchers: matchers,
}
return &MultiGeoIPMatcher{
matchers: matchers,
destFunc: destFunc,
}, nil
if onSource {
matcher.destFunc = sourceFromContext
} else {
matcher.destFunc = targetFromContent
matcher.resolvedIPFunc = resolvedIPFromContext
}
return matcher, nil
}
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
@@ -146,10 +156,12 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
if dest.IsValid() && dest.Address.Family().IsIP() {
ips = append(ips, dest.Address.IP())
} else if resolver, ok := ResolvedIPsFromContext(ctx); ok {
resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
ips = append(ips, rip.IP())
}
if m.resolvedIPFunc != nil {
rips := m.resolvedIPFunc(ctx)
if len(rips) > 0 {
ips = append(ips, rips...)
}
}

View File

@@ -7,32 +7,12 @@ import (
"v2ray.com/core"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/features/dns"
"v2ray.com/core/features/outbound"
"v2ray.com/core/features/routing"
)
type key uint32
const (
resolvedIPsKey key = iota
)
type IPResolver interface {
Resolve() []net.Address
}
func ContextWithResolveIPs(ctx context.Context, f IPResolver) context.Context {
return context.WithValue(ctx, resolvedIPsKey, f)
}
func ResolvedIPsFromContext(ctx context.Context) (IPResolver, bool) {
ips, ok := ctx.Value(resolvedIPsKey).(IPResolver)
return ips, ok
}
func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
r := new(Router)
@@ -91,34 +71,6 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error
return nil
}
type ipResolver struct {
dns dns.Client
ip []net.Address
domain string
resolved bool
}
func (r *ipResolver) Resolve() []net.Address {
if r.resolved {
return r.ip
}
newError("looking for IP for domain: ", r.domain).WriteToLog()
r.resolved = true
ips, err := r.dns.LookupIP(r.domain)
if err != nil {
newError("failed to get IP address").Base(err).WriteToLog()
}
if len(ips) == 0 {
return nil
}
r.ip = make([]net.Address, len(ips))
for i, ip := range ips {
r.ip[i] = net.IPAddress(ip)
}
return r.ip
}
func (r *Router) PickRoute(ctx context.Context) (string, error) {
rule, err := r.pickRouteInternal(ctx)
if err != nil {
@@ -127,17 +79,27 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) {
return rule.GetTag()
}
// PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
resolver := &ipResolver{
dns: r.dns,
func isDomainOutbound(outbound *session.Outbound) bool {
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
}
func (r *Router) resolveIP(outbound *session.Outbound) error {
domain := outbound.Target.Address.Domain()
ips, err := r.dns.LookupIP(domain)
if err != nil {
return err
}
outbound.ResolvedIPs = ips
return nil
}
// PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
outbound := session.OutboundFromContext(ctx)
if r.domainStrategy == Config_IpOnDemand {
if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() {
resolver.domain = outbound.Target.Address.Domain()
ctx = ContextWithResolveIPs(ctx, resolver)
if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
if err := r.resolveIP(outbound); err != nil {
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
@@ -147,21 +109,19 @@ func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
}
}
if outbound == nil || !outbound.Target.IsValid() {
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
return nil, common.ErrNoClue
}
dest := outbound.Target
if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
resolver.domain = dest.Address.Domain()
ips := resolver.Resolve()
if len(ips) > 0 {
ctx = ContextWithResolveIPs(ctx, resolver)
for _, rule := range r.rules {
if rule.Apply(ctx) {
return rule, nil
}
}
if err := r.resolveIP(outbound); err != nil {
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
return nil, common.ErrNoClue
}
// Try applying rules again if we have IPs.
for _, rule := range r.rules {
if rule.Apply(ctx) {
return rule, nil
}
}

View File

@@ -125,3 +125,72 @@ func TestIPOnDemand(t *testing.T) {
t.Error("expect tag 'test', bug actually ", tag)
}
}
func TestIPIfNonMatchDomain(t *testing.T) {
config := &Config{
DomainStrategy: Config_IpIfNonMatch,
Rule: []*RoutingRule{
{
TargetTag: &RoutingRule_Tag{
Tag: "test",
},
Cidr: []*CIDR{
{
Ip: []byte{192, 168, 0, 0},
Prefix: 16,
},
},
},
},
}
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl)
mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
}
}
func TestIPIfNonMatchIP(t *testing.T) {
config := &Config{
DomainStrategy: Config_IpIfNonMatch,
Rule: []*RoutingRule{
{
TargetTag: &RoutingRule_Tag{
Tag: "test",
},
Cidr: []*CIDR{
{
Ip: []byte{127, 0, 0, 0},
Prefix: 8,
},
},
},
},
}
mockCtl := gomock.NewController(t)
defer mockCtl.Finish()
mockDns := mocks.NewDNSClient(mockCtl)
r := new(Router)
common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
tag, err := r.PickRoute(ctx)
common.Must(err)
if tag != "test" {
t.Error("expect tag 'test', bug actually ", tag)
}
}

View File

@@ -7,6 +7,7 @@ import (
"net"
"testing"
"golang.org/x/sync/errgroup"
"v2ray.com/core/common"
. "v2ray.com/core/common/buf"
"v2ray.com/core/common/compare"
@@ -31,12 +32,17 @@ func TestReadvReader(t *testing.T) {
data := make([]byte, 8192)
common.Must2(rand.Read(data))
go func() {
var errg errgroup.Group
errg.Go(func() error {
writer := NewWriter(conn)
mb := MergeBytes(nil, data)
if err := writer.WriteMultiBuffer(mb); err != nil {
t.Fatal("failed to write data: ", err)
return writer.WriteMultiBuffer(mb)
})
defer func() {
if err := errg.Wait(); err != nil {
t.Error(err)
}
}()

View File

@@ -2,7 +2,8 @@ package task
import "v2ray.com/core/common"
func Close(v interface{}) Task {
// Close returns a func() that closes v.
func Close(v interface{}) func() error {
return func() error {
return common.Close(v)
}

View File

@@ -6,121 +6,25 @@ import (
"v2ray.com/core/common/signal/semaphore"
)
type Task func() error
type executionContext struct {
ctx context.Context
tasks []Task
onSuccess Task
onFailure Task
}
func (c *executionContext) executeTask() error {
if len(c.tasks) == 0 {
return nil
}
// Reuse current goroutine if we only have one task to run.
if len(c.tasks) == 1 && c.ctx == nil {
return c.tasks[0]()
}
ctx := context.Background()
if c.ctx != nil {
ctx = c.ctx
}
return executeParallel(ctx, c.tasks)
}
func (c *executionContext) run() error {
err := c.executeTask()
if err == nil && c.onSuccess != nil {
return c.onSuccess()
}
if err != nil && c.onFailure != nil {
return c.onFailure()
}
return err
}
type ExecutionOption func(*executionContext)
func WithContext(ctx context.Context) ExecutionOption {
return func(c *executionContext) {
c.ctx = ctx
}
}
func Parallel(tasks ...Task) ExecutionOption {
return func(c *executionContext) {
c.tasks = append(c.tasks, tasks...)
}
}
// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
// Once a task returns an error, the following tasks will not run.
func Sequential(tasks ...Task) ExecutionOption {
return func(c *executionContext) {
switch len(tasks) {
case 0:
return
case 1:
c.tasks = append(c.tasks, tasks[0])
default:
c.tasks = append(c.tasks, func() error {
return execute(tasks...)
})
}
}
}
func OnSuccess(task Task) ExecutionOption {
return func(c *executionContext) {
c.onSuccess = task
}
}
func OnFailure(task Task) ExecutionOption {
return func(c *executionContext) {
c.onFailure = task
}
}
func Single(task Task, opts ...ExecutionOption) Task {
return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
}
func Run(opts ...ExecutionOption) Task {
var c executionContext
for _, opt := range opts {
opt(&c)
}
// OnSuccess executes g() after f() returns nil.
func OnSuccess(f func() error, g func() error) func() error {
return func() error {
return c.run()
}
}
// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
func execute(tasks ...Task) error {
for _, task := range tasks {
if err := task(); err != nil {
if err := f(); err != nil {
return err
}
return g()
}
return nil
}
// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
func executeParallel(ctx context.Context, tasks []Task) error {
// Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass.
func Run(ctx context.Context, tasks ...func() error) error {
n := len(tasks)
s := semaphore.New(n)
done := make(chan error, 1)
for _, task := range tasks {
<-s.Wait()
go func(f Task) {
go func(f func() error) {
err := f()
if err == nil {
s.Signal()

View File

@@ -14,13 +14,14 @@ import (
func TestExecuteParallel(t *testing.T) {
assert := With(t)
err := Run(Parallel(func() error {
time.Sleep(time.Millisecond * 200)
return errors.New("test")
}, func() error {
time.Sleep(time.Millisecond * 500)
return errors.New("test2")
}))()
err := Run(context.Background(),
func() error {
time.Sleep(time.Millisecond * 200)
return errors.New("test")
}, func() error {
time.Sleep(time.Millisecond * 500)
return errors.New("test2")
})
assert(err.Error(), Equals, "test")
}
@@ -29,7 +30,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
assert := With(t)
ctx, cancel := context.WithCancel(context.Background())
err := Run(WithContext(ctx), Parallel(func() error {
err := Run(ctx, func() error {
time.Sleep(time.Millisecond * 2000)
return errors.New("test")
}, func() error {
@@ -38,7 +39,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
}, func() error {
cancel()
return nil
}))()
})
assert(err.Error(), HasSubstring, "canceled")
}
@@ -48,7 +49,7 @@ func BenchmarkExecuteOne(b *testing.B) {
return nil
}
for i := 0; i < b.N; i++ {
common.Must(Run(Parallel(noop))())
common.Must(Run(context.Background(), noop))
}
}
@@ -57,17 +58,6 @@ func BenchmarkExecuteTwo(b *testing.B) {
return nil
}
for i := 0; i < b.N; i++ {
common.Must(Run(Parallel(noop, noop))())
}
}
func BenchmarkExecuteContext(b *testing.B) {
noop := func() error {
return nil
}
background := context.Background()
for i := 0; i < b.N; i++ {
common.Must(Run(WithContext(background), Parallel(noop, noop))())
common.Must(Run(context.Background(), noop, noop))
}
}

View File

@@ -1,16 +1,12 @@
package localdns
import (
"context"
"v2ray.com/core/common/net"
"v2ray.com/core/features/dns"
)
// Client is an implementation of dns.Client, which queries localhost for DNS.
type Client struct {
resolver net.Resolver
}
type Client struct{}
// Type implements common.HasType.
func (*Client) Type() interface{} {
@@ -24,16 +20,19 @@ func (*Client) Start() error { return nil }
func (*Client) Close() error { return nil }
// LookupIP implements Client.
func (c *Client) LookupIP(host string) ([]net.IP, error) {
ipAddr, err := c.resolver.LookupIPAddr(context.Background(), host)
func (*Client) LookupIP(host string) ([]net.IP, error) {
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
ips := make([]net.IP, 0, len(ipAddr))
for _, addr := range ipAddr {
ips = append(ips, addr.IP)
parsedIPs := make([]net.IP, 0, len(ips))
for _, ip := range ips {
parsed := net.IPAddress(ip)
if parsed != nil {
parsedIPs = append(parsedIPs, parsed.IP())
}
}
return ips, nil
return parsedIPs, nil
}
// LookupIPv4 implements IPv4Lookup.
@@ -42,11 +41,10 @@ func (c *Client) LookupIPv4(host string) ([]net.IP, error) {
if err != nil {
return nil, err
}
var ipv4 []net.IP
ipv4 := make([]net.IP, 0, len(ips))
for _, ip := range ips {
parsed := net.IPAddress(ip)
if parsed.Family().IsIPv4() {
ipv4 = append(ipv4, parsed.IP())
if len(ip) == net.IPv4len {
ipv4 = append(ipv4, ip)
}
}
return ipv4, nil
@@ -58,11 +56,10 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) {
if err != nil {
return nil, err
}
var ipv6 []net.IP
ipv6 := make([]net.IP, 0, len(ips))
for _, ip := range ips {
parsed := net.IPAddress(ip)
if parsed.Family().IsIPv6() {
ipv6 = append(ipv6, parsed.IP())
if len(ip) == net.IPv6len {
ipv6 = append(ipv6, ip)
}
}
return ipv6, nil
@@ -70,9 +67,5 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) {
// New create a new dns.Client that queries localhost for DNS.
func New() *Client {
return &Client{
resolver: net.Resolver{
PreferGo: true,
},
}
return &Client{}
}

View File

@@ -4,6 +4,7 @@ package core
//go:generate go install github.com/golang/mock/mockgen
//go:generate mockgen -package mocks -destination testing/mocks/io.go -mock_names Reader=Reader,Writer=Writer io Reader,Writer
//go:generate mockgen -package mocks -destination testing/mocks/log.go -mock_names Handler=LogHandler v2ray.com/core/common/log Handler
//go:generate mockgen -package mocks -destination testing/mocks/mux.go -mock_names ClientWorkerFactory=MuxClientWorkerFactory v2ray.com/core/common/mux ClientWorkerFactory
//go:generate mockgen -package mocks -destination testing/mocks/dns.go -mock_names Client=DNSClient v2ray.com/core/features/dns Client
//go:generate mockgen -package mocks -destination testing/mocks/outbound.go -mock_names Manager=OutboundManager,HandlerSelector=OutboundHandlerSelector v2ray.com/core/features/outbound Manager,HandlerSelector

View File

@@ -147,10 +147,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
return nil
}
if err := task.Run(task.WithContext(ctx),
task.Parallel(
task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))),
responseDone))(); err != nil {
if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@@ -167,7 +167,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return nil
}
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil {
if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -210,8 +210,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
return nil
}
var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil {
var closeWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, closeWriter, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)
@@ -307,7 +307,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
return nil
}
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDone))(); err != nil {
if err := task.Run(ctx, requestDone, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@@ -62,8 +62,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(connReader, link.Writer)
}
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -141,8 +141,8 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
}
var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@@ -129,8 +129,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer))
}
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}
@@ -167,8 +167,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
return nil
}
var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -229,8 +229,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
return nil
}
var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil {
var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@@ -137,8 +137,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
}
}
var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil {
var responseDonePost = task.OnSuccess(responseFunc, task.Close(link.Writer))
if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -164,8 +164,8 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
return nil
}
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@@ -302,8 +302,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
return transferResponse(timer, svrSession, request, response, link.Reader, writer)
}
var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
pipe.CloseError(link.Reader)
pipe.CloseError(link.Writer)
return newError("connection ends").Base(err)

View File

@@ -161,8 +161,8 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
}
var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output)))
if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil {
var responseDonePost = task.OnSuccess(responseDone, task.Close(output))
if err := task.Run(ctx, requestDone, responseDonePost); err != nil {
return newError("connection ends").Base(err)
}

View File

@@ -66,9 +66,9 @@ curl -L -o release/config/geoip.dat "https://github.com/v2ray/geoip/releases/dow
pushd $GOPATH/src
# Flatten vendor directories
cp -r v2ray.com/core/vendor/github.com/ ./github.com/
rm -rf v2ray.com/core/vendor/github.com/
cp -r github.com/lucas-clemente/quic-go/vendor/github.com/
cp -r v2ray.com/core/vendor/github.com/ .
rm -rf v2ray.com/core/vendor/
cp -r github.com/lucas-clemente/quic-go/vendor/github.com/ .
rm -rf github.com/lucas-clemente/quic-go/vendor/
# Create zip file for all sources

44
testing/mocks/log.go Normal file
View File

@@ -0,0 +1,44 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: v2ray.com/core/common/log (interfaces: Handler)
// Package mocks is a generated GoMock package.
package mocks
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
log "v2ray.com/core/common/log"
)
// LogHandler is a mock of Handler interface
type LogHandler struct {
ctrl *gomock.Controller
recorder *LogHandlerMockRecorder
}
// LogHandlerMockRecorder is the mock recorder for LogHandler
type LogHandlerMockRecorder struct {
mock *LogHandler
}
// NewLogHandler creates a new mock instance
func NewLogHandler(ctrl *gomock.Controller) *LogHandler {
mock := &LogHandler{ctrl: ctrl}
mock.recorder = &LogHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *LogHandler) EXPECT() *LogHandlerMockRecorder {
return m.recorder
}
// Handle mocks base method
func (m *LogHandler) Handle(arg0 log.Message) {
m.ctrl.Call(m, "Handle", arg0)
}
// Handle indicates an expected call of Handle
func (mr *LogHandlerMockRecorder) Handle(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handle", reflect.TypeOf((*LogHandler)(nil).Handle), arg0)
}

View File

@@ -64,7 +64,7 @@ func (server *Server) handleConnection(conn net.Conn) {
}
pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
err := task.Run(task.Parallel(func() error {
err := task.Run(context.Background(), func() error {
defer pWriter.Close() // nolint: errcheck
for {
@@ -96,7 +96,7 @@ func (server *Server) handleConnection(conn net.Conn) {
return err
}
}
}))()
})
if err != nil {
fmt.Println("failed to transfer data: ", err.Error())

View File

@@ -10,7 +10,7 @@ import (
)
var (
effectiveSystemDialer SystemDialer = DefaultSystemDialer{}
effectiveSystemDialer SystemDialer = &DefaultSystemDialer{}
)
type SystemDialer interface {
@@ -18,23 +18,32 @@ type SystemDialer interface {
}
type DefaultSystemDialer struct {
controllers []controller
}
func (DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: time.Second * 60,
DualStack: true,
}
if sockopt != nil {
if sockopt != nil || len(d.controllers) > 0 {
dialer.Control = func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
if err := applyOutboundSocketOptions(network, address, fd, sockopt); err != nil {
newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx))
if sockopt != nil {
if err := applyOutboundSocketOptions(network, address, fd, sockopt); err != nil {
newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
if dest.Network == net.Network_UDP && len(sockopt.BindAddress) > 0 && sockopt.BindPort > 0 {
if err := bindAddr(fd, sockopt.BindAddress, sockopt.BindPort); err != nil {
newError("failed to bind source address to ", sockopt.BindAddress).Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
}
if dest.Network == net.Network_UDP && len(sockopt.BindAddress) > 0 && sockopt.BindPort > 0 {
if err := bindAddr(fd, sockopt.BindAddress, sockopt.BindPort); err != nil {
newError("failed to bind source address to ", sockopt.BindAddress).Base(err).WriteToLog(session.ExportIDToError(ctx))
for _, ctl := range d.controllers {
if err := ctl(network, address, fd); err != nil {
newError("failed to apply external controller").Base(err).WriteToLog(session.ExportIDToError(ctx))
}
}
})
@@ -83,7 +92,26 @@ func (v *SimpleSystemDialer) Dial(ctx context.Context, src net.Address, dest net
// v2ray:api:stable
func UseAlternativeSystemDialer(dialer SystemDialer) {
if dialer == nil {
effectiveSystemDialer = DefaultSystemDialer{}
effectiveSystemDialer = &DefaultSystemDialer{}
}
effectiveSystemDialer = dialer
}
// RegisterDialerController adds a controller to the effective system dialer.
// The controller can be used to operate on file descriptors before they are put into use.
// It only works when effective dialer is the default dialer.
//
// v2ray:api:beta
func RegisterDialerController(ctl func(network, address string, fd uintptr) error) error {
if ctl == nil {
return newError("nil listener controller")
}
dialer, ok := effectiveSystemDialer.(*DefaultSystemDialer)
if !ok {
return newError("RegisterListenerController not supported in custom dialer")
}
dialer.controllers = append(dialer.controllers, ctl)
return nil
}

View File

@@ -1,6 +1,7 @@
package pipe_test
import (
"context"
"io"
"sync"
"testing"
@@ -73,7 +74,7 @@ func TestPipeLimitZero(t *testing.T) {
bb.Write([]byte{'a', 'b'})
assert(pWriter.WriteMultiBuffer(buf.MultiBuffer{bb}), IsNil)
err := task.Run(task.Parallel(func() error {
err := task.Run(context.Background(), func() error {
b := buf.New()
b.Write([]byte{'c', 'd'})
return pWriter.WriteMultiBuffer(buf.MultiBuffer{b})
@@ -91,7 +92,7 @@ func TestPipeLimitZero(t *testing.T) {
time.Sleep(time.Second * 2)
pWriter.Close()
return nil
}))()
})
assert(err, IsNil)
}

View File

@@ -8,7 +8,6 @@ import (
"v2ray.com/core/app/dispatcher"
"v2ray.com/core/app/proxyman"
"v2ray.com/core/common"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
@@ -19,6 +18,7 @@ import (
"v2ray.com/core/proxy/dokodemo"
"v2ray.com/core/proxy/vmess"
"v2ray.com/core/proxy/vmess/outbound"
"v2ray.com/core/testing/servers/tcp"
)
func TestV2RayDependency(t *testing.T) {
@@ -36,7 +36,8 @@ func TestV2RayDependency(t *testing.T) {
}
func TestV2RayClose(t *testing.T) {
port := net.Port(dice.RollUint16())
port := tcp.PickPort()
userId := uuid.New()
config := &Config{
App: []*serial.TypedMessage{