refactor dialer

pull/432/head
Darien Raymond 2017-02-23 23:48:47 +01:00
parent 2d83e260ac
commit 9cbc9b7170
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
7 changed files with 14 additions and 11 deletions

View File

@ -53,6 +53,6 @@ func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) {
} }
// DialSystem calls system dialer to create a network connection. // DialSystem calls system dialer to create a network connection.
func DialSystem(src v2net.Address, dest v2net.Destination) (net.Conn, error) { func DialSystem(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) {
return effectiveSystemDialer.Dial(src, dest) return effectiveSystemDialer.Dial(ctx, src, dest)
} }

View File

@ -7,6 +7,7 @@ import (
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
"v2ray.com/core/testing/servers/tcp" "v2ray.com/core/testing/servers/tcp"
. "v2ray.com/core/transport/internet" . "v2ray.com/core/transport/internet"
"context"
) )
func TestDialWithLocalAddr(t *testing.T) { func TestDialWithLocalAddr(t *testing.T) {
@ -17,7 +18,7 @@ func TestDialWithLocalAddr(t *testing.T) {
assert.Error(err).IsNil() assert.Error(err).IsNil()
defer server.Close() defer server.Close()
conn, err := DialSystem(net.LocalHostIP, net.TCPDestination(net.LocalHostIP, dest.Port)) conn, err := DialSystem(context.Background(), net.LocalHostIP, net.TCPDestination(net.LocalHostIP, dest.Port))
assert.Error(err).IsNil() assert.Error(err).IsNil()
assert.String(conn.RemoteAddr().String()).Equals("127.0.0.1:" + dest.Port.String()) assert.String(conn.RemoteAddr().String()).Equals("127.0.0.1:" + dest.Port.String())
conn.Close() conn.Close()

View File

@ -117,7 +117,7 @@ func DialKCP(ctx context.Context, dest v2net.Destination) (internet.Connection,
id := internal.NewConnectionID(src, dest) id := internal.NewConnectionID(src, dest)
conn := globalPool.Get(id) conn := globalPool.Get(id)
if conn == nil { if conn == nil {
rawConn, err := internet.DialSystem(src, dest) rawConn, err := internet.DialSystem(ctx, src, dest)
if err != nil { if err != nil {
log.Error("KCP|Dialer: Failed to dial to dest: ", err) log.Error("KCP|Dialer: Failed to dial to dest: ", err)
return nil, err return nil, err

View File

@ -4,6 +4,8 @@ import (
"net" "net"
"time" "time"
"context"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
) )
@ -12,13 +14,13 @@ var (
) )
type SystemDialer interface { type SystemDialer interface {
Dial(source v2net.Address, destination v2net.Destination) (net.Conn, error) Dial(ctx context.Context, source v2net.Address, destination v2net.Destination) (net.Conn, error)
} }
type DefaultSystemDialer struct { type DefaultSystemDialer struct {
} }
func (v *DefaultSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (net.Conn, error) { func (v *DefaultSystemDialer) Dial(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) {
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: time.Second * 60, Timeout: time.Second * 60,
DualStack: true, DualStack: true,
@ -38,7 +40,7 @@ func (v *DefaultSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (n
} }
dialer.LocalAddr = addr dialer.LocalAddr = addr
} }
return dialer.Dial(dest.Network.SystemString(), dest.NetAddr()) return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr())
} }
type SystemDialerAdapter interface { type SystemDialerAdapter interface {
@ -55,7 +57,7 @@ func WithAdapter(dialer SystemDialerAdapter) SystemDialer {
} }
} }
func (v *SimpleSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (net.Conn, error) { func (v *SimpleSystemDialer) Dial(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) {
return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr()) return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr())
} }

View File

@ -31,7 +31,7 @@ func Dial(ctx context.Context, dest v2net.Destination) (internet.Connection, err
} }
if conn == nil { if conn == nil {
var err error var err error
conn, err = internet.DialSystem(src, dest) conn, err = internet.DialSystem(ctx, src, dest)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -13,7 +13,7 @@ func init() {
common.Must(internet.RegisterTransportDialer(internet.TransportProtocol_UDP, common.Must(internet.RegisterTransportDialer(internet.TransportProtocol_UDP,
func(ctx context.Context, dest v2net.Destination) (internet.Connection, error) { func(ctx context.Context, dest v2net.Destination) (internet.Connection, error) {
src := internet.DialerSourceFromContext(ctx) src := internet.DialerSourceFromContext(ctx)
conn, err := internet.DialSystem(src, dest) conn, err := internet.DialSystem(ctx, src, dest)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -47,7 +47,7 @@ func dialWebsocket(ctx context.Context, dest v2net.Destination) (net.Conn, error
wsSettings := internet.TransportSettingsFromContext(ctx).(*Config) wsSettings := internet.TransportSettingsFromContext(ctx).(*Config)
commonDial := func(network, addr string) (net.Conn, error) { commonDial := func(network, addr string) (net.Conn, error) {
return internet.DialSystem(src, dest) return internet.DialSystem(ctx, src, dest)
} }
dialer := websocket.Dialer{ dialer := websocket.Dialer{