From 9cbc9b717092672872924701e13ebb08f958c737 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Thu, 23 Feb 2017 23:48:47 +0100 Subject: [PATCH] refactor dialer --- transport/internet/dialer.go | 4 ++-- transport/internet/dialer_test.go | 3 ++- transport/internet/kcp/dialer.go | 2 +- transport/internet/system_dialer.go | 10 ++++++---- transport/internet/tcp/dialer.go | 2 +- transport/internet/udp/dialer.go | 2 +- transport/internet/websocket/dialer.go | 2 +- 7 files changed, 14 insertions(+), 11 deletions(-) diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 1452a843..547d21d1 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -53,6 +53,6 @@ func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) { } // DialSystem calls system dialer to create a network connection. -func DialSystem(src v2net.Address, dest v2net.Destination) (net.Conn, error) { - return effectiveSystemDialer.Dial(src, dest) +func DialSystem(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) { + return effectiveSystemDialer.Dial(ctx, src, dest) } diff --git a/transport/internet/dialer_test.go b/transport/internet/dialer_test.go index 41858065..77f02544 100644 --- a/transport/internet/dialer_test.go +++ b/transport/internet/dialer_test.go @@ -7,6 +7,7 @@ import ( "v2ray.com/core/testing/assert" "v2ray.com/core/testing/servers/tcp" . "v2ray.com/core/transport/internet" + "context" ) func TestDialWithLocalAddr(t *testing.T) { @@ -17,7 +18,7 @@ func TestDialWithLocalAddr(t *testing.T) { assert.Error(err).IsNil() defer server.Close() - conn, err := DialSystem(net.LocalHostIP, net.TCPDestination(net.LocalHostIP, dest.Port)) + conn, err := DialSystem(context.Background(), net.LocalHostIP, net.TCPDestination(net.LocalHostIP, dest.Port)) assert.Error(err).IsNil() assert.String(conn.RemoteAddr().String()).Equals("127.0.0.1:" + dest.Port.String()) conn.Close() diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index ad8604c1..61c5a0f3 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -117,7 +117,7 @@ func DialKCP(ctx context.Context, dest v2net.Destination) (internet.Connection, id := internal.NewConnectionID(src, dest) conn := globalPool.Get(id) if conn == nil { - rawConn, err := internet.DialSystem(src, dest) + rawConn, err := internet.DialSystem(ctx, src, dest) if err != nil { log.Error("KCP|Dialer: Failed to dial to dest: ", err) return nil, err diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 9427bde8..f5cb0b49 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -4,6 +4,8 @@ import ( "net" "time" + "context" + v2net "v2ray.com/core/common/net" ) @@ -12,13 +14,13 @@ var ( ) type SystemDialer interface { - Dial(source v2net.Address, destination v2net.Destination) (net.Conn, error) + Dial(ctx context.Context, source v2net.Address, destination v2net.Destination) (net.Conn, error) } type DefaultSystemDialer struct { } -func (v *DefaultSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (net.Conn, error) { +func (v *DefaultSystemDialer) Dial(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) { dialer := &net.Dialer{ Timeout: time.Second * 60, DualStack: true, @@ -38,7 +40,7 @@ func (v *DefaultSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (n } dialer.LocalAddr = addr } - return dialer.Dial(dest.Network.SystemString(), dest.NetAddr()) + return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr()) } type SystemDialerAdapter interface { @@ -55,7 +57,7 @@ func WithAdapter(dialer SystemDialerAdapter) SystemDialer { } } -func (v *SimpleSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (net.Conn, error) { +func (v *SimpleSystemDialer) Dial(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) { return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr()) } diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 7fa4eeab..7e5625f5 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -31,7 +31,7 @@ func Dial(ctx context.Context, dest v2net.Destination) (internet.Connection, err } if conn == nil { var err error - conn, err = internet.DialSystem(src, dest) + conn, err = internet.DialSystem(ctx, src, dest) if err != nil { return nil, err } diff --git a/transport/internet/udp/dialer.go b/transport/internet/udp/dialer.go index 1f940fba..e70220f3 100644 --- a/transport/internet/udp/dialer.go +++ b/transport/internet/udp/dialer.go @@ -13,7 +13,7 @@ func init() { common.Must(internet.RegisterTransportDialer(internet.TransportProtocol_UDP, func(ctx context.Context, dest v2net.Destination) (internet.Connection, error) { src := internet.DialerSourceFromContext(ctx) - conn, err := internet.DialSystem(src, dest) + conn, err := internet.DialSystem(ctx, src, dest) if err != nil { return nil, err } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index e66f5e5d..dec79a8b 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -47,7 +47,7 @@ func dialWebsocket(ctx context.Context, dest v2net.Destination) (net.Conn, error wsSettings := internet.TransportSettingsFromContext(ctx).(*Config) commonDial := func(network, addr string) (net.Conn, error) { - return internet.DialSystem(src, dest) + return internet.DialSystem(ctx, src, dest) } dialer := websocket.Dialer{