From a1b552f9487d2687228dd1d89003087331f62eab Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 6 Jan 2019 09:39:37 +0100 Subject: [PATCH] use ListenPacket in Dial UDP connection --- transport/internet/system_dialer.go | 94 ++++++++++++++++++++++++----- 1 file changed, 79 insertions(+), 15 deletions(-) diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index e5259ad4..6588e7ea 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -21,10 +21,51 @@ type DefaultSystemDialer struct { controllers []controller } +func resolveSrcAddr(network net.Network, src net.Address) net.Addr { + if src == nil || src == net.AnyIP { + return nil + } + + if network == net.Network_TCP { + return &net.TCPAddr{ + IP: src.IP(), + Port: 0, + } + } + + return &net.UDPAddr{ + IP: src.IP(), + Port: 0, + } +} + func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination, sockopt *SocketConfig) (net.Conn, error) { + if dest.Network == net.Network_UDP { + srcAddr := resolveSrcAddr(net.Network_UDP, src) + if srcAddr == nil { + srcAddr = &net.UDPAddr{ + IP: []byte{0, 0, 0, 0}, + Port: 0, + } + } + packetConn, err := ListenSystemPacket(ctx, srcAddr, sockopt) + if err != nil { + return nil, err + } + destAddr, err := net.ResolveUDPAddr("udp", dest.NetAddr()) + if err != nil { + return nil, err + } + return &packetConnWrapper{ + conn: packetConn, + dest: destAddr, + }, nil + } + dialer := &net.Dialer{ Timeout: time.Second * 60, DualStack: true, + LocalAddr: resolveSrcAddr(dest.Network, src), } if sockopt != nil || len(d.controllers) > 0 { @@ -50,24 +91,47 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne } } - if src != nil && src != net.AnyIP { - var addr net.Addr - if dest.Network == net.Network_TCP { - addr = &net.TCPAddr{ - IP: src.IP(), - Port: 0, - } - } else { - addr = &net.UDPAddr{ - IP: src.IP(), - Port: 0, - } - } - dialer.LocalAddr = addr - } return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr()) } +type packetConnWrapper struct { + conn net.PacketConn + dest net.Addr +} + +func (c *packetConnWrapper) Close() error { + return c.conn.Close() +} + +func (c *packetConnWrapper) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *packetConnWrapper) RemoteAddr() net.Addr { + return c.dest +} + +func (c *packetConnWrapper) Write(p []byte) (int, error) { + return c.conn.WriteTo(p, c.dest) +} + +func (c *packetConnWrapper) Read(p []byte) (int, error) { + n, _, err := c.conn.ReadFrom(p) + return n, err +} + +func (c *packetConnWrapper) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *packetConnWrapper) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *packetConnWrapper) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + type SystemDialerAdapter interface { Dial(network string, address string) (net.Conn, error) }