diff --git a/common/net/packet.go b/common/net/packet.go index 42216669..8f235bf2 100644 --- a/common/net/packet.go +++ b/common/net/packet.go @@ -12,10 +12,11 @@ func NewTCPPacket(dest Destination) *TCPPacket { } } -func NewUDPPacket(dest Destination, data []byte) *UDPPacket { +func NewUDPPacket(dest Destination, data []byte, id uint16) *UDPPacket { return &UDPPacket{ basePacket: basePacket{destination: dest}, data: data, + id: id, } } @@ -42,6 +43,11 @@ func (packet *TCPPacket) MoreChunks() bool { type UDPPacket struct { basePacket data []byte + id uint16 +} + +func (packet *UDPPacket) ID() uint16 { + return packet.id } func (packet *UDPPacket) Chunk() []byte { diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go index 4bdeea55..28f153ef 100644 --- a/proxy/socks/udp.go +++ b/proxy/socks/udp.go @@ -1,8 +1,12 @@ package socks import ( + "math" + "math/rand" "net" + "sync" + "github.com/v2ray/v2ray-core/common/collect" "github.com/v2ray/v2ray-core/common/log" v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/proxy/socks/protocol" @@ -12,6 +16,50 @@ const ( bufferSize = 2 * 1024 ) +type portMap struct { + access sync.Mutex + data map[uint16]*net.UDPAddr + removedPorts *collect.TimedQueue +} + +func newPortMap() *portMap { + m := &portMap{ + access: sync.Mutex{}, + data: make(map[uint16]*net.UDPAddr), + removedPorts: collect.NewTimedQueue(1), + } + go m.removePorts(m.removedPorts.RemovedEntries()) + return m +} + +func (m *portMap) assignAddressToken(addr *net.UDPAddr) uint16 { + for { + token := uint16(rand.Intn(math.MaxUint16)) + if _, used := m.data[token]; !used { + m.access.Lock() + if _, used = m.data[token]; !used { + m.data[token] = addr + m.access.Unlock() + return token + } + m.access.Unlock() + } + } +} + +func (m *portMap) removePorts(removedPorts <-chan interface{}) { + for { + rawToken := <-removedPorts + m.access.Lock() + delete(m.data, rawToken.(uint16)) + m.access.Unlock() + } +} + +var ( + ports = newPortMap() +) + func (server *SocksServer) ListenUDP(port uint16) error { addr := &net.UDPAddr{ IP: net.IP{0, 0, 0, 0}, @@ -31,7 +79,7 @@ func (server *SocksServer) ListenUDP(port uint16) error { func (server *SocksServer) AcceptPackets(conn *net.UDPConn) error { for { buffer := make([]byte, 0, bufferSize) - nBytes, _, err := conn.ReadFromUDP(buffer) + nBytes, addr, err := conn.ReadFromUDP(buffer) if err != nil { log.Error("Socks failed to read UDP packets: %v", err) return err @@ -46,7 +94,9 @@ func (server *SocksServer) AcceptPackets(conn *net.UDPConn) error { continue } - udpPacket := v2net.NewUDPPacket(request.Destination(), request.Data) + token := ports.assignAddressToken(addr) + + udpPacket := v2net.NewUDPPacket(request.Destination(), request.Data, token) server.vPoint.DispatchToOutbound(udpPacket) } }