diff --git a/common/net/address.go b/common/net/address.go index f041c95c..5032f9dc 100644 --- a/common/net/address.go +++ b/common/net/address.go @@ -3,66 +3,142 @@ package net import ( "net" "strconv" + + "github.com/v2ray/v2ray-core/common/log" ) -const ( - AddrTypeIP = byte(0x01) - AddrTypeDomain = byte(0x03) -) +type Address interface { + IP() net.IP + Domain() string + Port() uint16 + PortBytes() []byte -type Address struct { - Type byte - IP net.IP - Domain string - Port uint16 + IsIPv4() bool + IsIPv6() bool + IsDomain() bool + + String() string } func IPAddress(ip []byte, port uint16) Address { - ipCopy := make([]byte, len(ip)) - copy(ipCopy, ip) - // TODO: check IP length - return Address{ - Type: AddrTypeIP, - IP: net.IP(ipCopy), - Domain: "", - Port: port, + switch len(ip) { + case net.IPv4len: + return IPv4Address{ + PortAddress: PortAddress{port: port}, + ip: [4]byte{ip[0], ip[1], ip[2], ip[3]}, + } + case net.IPv6len: + return IPv6Address{ + PortAddress: PortAddress{port: port}, + ip: [16]byte{ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], ip[8], ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15]}, + } + default: + panic(log.Error("Unknown IP format: %v", ip)) } } func DomainAddress(domain string, port uint16) Address { - return Address{ - Type: AddrTypeDomain, - IP: nil, - Domain: domain, - Port: port, + return DomainAddressImpl{ + domain: domain, + PortAddress: PortAddress{port: port}, } } -func (addr Address) IsIPv4() bool { - return addr.Type == AddrTypeIP && len(addr.IP) == net.IPv4len +type PortAddress struct { + port uint16 } -func (addr Address) IsIPv6() bool { - return addr.Type == AddrTypeIP && len(addr.IP) == net.IPv6len +func (addr PortAddress) Port() uint16 { + return addr.port } -func (addr Address) IsDomain() bool { - return addr.Type == AddrTypeDomain +func (addr PortAddress) PortBytes() []byte { + return []byte{byte(addr.port >> 8), byte(addr.port)} } -func (addr Address) String() string { - var host string - switch addr.Type { - case AddrTypeIP: - host = addr.IP.String() - if len(addr.IP) == net.IPv6len { - host = "[" + host + "]" - } - - case AddrTypeDomain: - host = addr.Domain - default: - panic("Unknown Address Type " + strconv.Itoa(int(addr.Type))) - } - return host + ":" + strconv.Itoa(int(addr.Port)) +type IPv4Address struct { + PortAddress + ip [4]byte +} + +func (addr IPv4Address) IP() net.IP { + return net.IP(addr.ip[:]) +} + +func (addr IPv4Address) Domain() string { + panic("Calling Domain() on an IPv4Address.") +} + +func (addr IPv4Address) IsIPv4() bool { + return true +} + +func (addr IPv4Address) IsIPv6() bool { + return false +} + +func (addr IPv4Address) IsDomain() bool { + return false +} + +func (addr IPv4Address) String() string { + return addr.IP().String() + ":" + strconv.Itoa(int(addr.PortAddress.port)) +} + +type IPv6Address struct { + PortAddress + ip [16]byte +} + +func (addr IPv6Address) IP() net.IP { + return net.IP(addr.ip[:]) +} + +func (addr IPv6Address) Domain() string { + panic("Calling Domain() on an IPv6Address.") +} + +func (addr IPv6Address) IsIPv4() bool { + return false +} + +func (addr IPv6Address) IsIPv6() bool { + return true +} + +func (addr IPv6Address) IsDomain() bool { + return false +} + +func (addr IPv6Address) String() string { + return "[" + addr.IP().String() + "]:" + strconv.Itoa(int(addr.PortAddress.port)) +} + +type DomainAddressImpl struct { + PortAddress + domain string +} + +func (addr DomainAddressImpl) IP() net.IP { + panic("Calling IP() on a DomainAddress.") +} + +func (addr DomainAddressImpl) Domain() string { + return addr.domain +} + +func (addr DomainAddressImpl) IsIPv4() bool { + return false +} + +func (addr DomainAddressImpl) IsIPv6() bool { + return false +} + +func (addr DomainAddressImpl) IsDomain() bool { + return true +} + +func (addr DomainAddressImpl) String() string { + return addr.domain + ":" + strconv.Itoa(int(addr.PortAddress.port)) } diff --git a/common/net/address_test.go b/common/net/address_test.go index f8af392c..9ddeb6f7 100644 --- a/common/net/address_test.go +++ b/common/net/address_test.go @@ -13,10 +13,9 @@ func TestIPv4Address(t *testing.T) { port := uint16(80) addr := IPAddress(ip, port) - assert.Byte(addr.Type).Equals(AddrTypeIP) assert.Bool(addr.IsIPv4()).IsTrue() - assert.Bytes(addr.IP).Equals(ip) - assert.Uint16(addr.Port).Equals(port) + assert.Bytes(addr.IP()).Equals(ip) + assert.Uint16(addr.Port()).Equals(port) assert.String(addr.String()).Equals("1.2.3.4:80") } @@ -32,10 +31,9 @@ func TestIPv6Address(t *testing.T) { port := uint16(443) addr := IPAddress(ip, port) - assert.Byte(addr.Type).Equals(AddrTypeIP) assert.Bool(addr.IsIPv6()).IsTrue() - assert.Bytes(addr.IP).Equals(ip) - assert.Uint16(addr.Port).Equals(port) + assert.Bytes(addr.IP()).Equals(ip) + assert.Uint16(addr.Port()).Equals(port) assert.String(addr.String()).Equals("[102:304:102:304:102:304:102:304]:443") } @@ -46,9 +44,8 @@ func TestDomainAddress(t *testing.T) { port := uint16(443) addr := DomainAddress(domain, port) - assert.Byte(addr.Type).Equals(AddrTypeDomain) assert.Bool(addr.IsDomain()).IsTrue() - assert.String(addr.Domain).Equals(domain) - assert.Uint16(addr.Port).Equals(port) + assert.String(addr.Domain()).Equals(domain) + assert.Uint16(addr.Port()).Equals(port) assert.String(addr.String()).Equals("v2ray.com:443") } diff --git a/proxy/vmess/protocol/vmess.go b/proxy/vmess/protocol/vmess.go index 570c6b69..13012521 100644 --- a/proxy/vmess/protocol/vmess.go +++ b/proxy/vmess/protocol/vmess.go @@ -211,22 +211,19 @@ func (request *VMessRequest) ToBytes(idHash user.CounterHash, randomRangeInt64 u buffer = append(buffer, request.RequestKey[:]...) buffer = append(buffer, request.ResponseHeader[:]...) buffer = append(buffer, request.Command) - - portBytes := make([]byte, 2) - binary.BigEndian.PutUint16(portBytes, request.Address.Port) - buffer = append(buffer, portBytes...) + buffer = append(buffer, request.Address.PortBytes()...) switch { case request.Address.IsIPv4(): buffer = append(buffer, addrTypeIPv4) - buffer = append(buffer, request.Address.IP...) + buffer = append(buffer, request.Address.IP()...) case request.Address.IsIPv6(): buffer = append(buffer, addrTypeIPv6) - buffer = append(buffer, request.Address.IP...) + buffer = append(buffer, request.Address.IP()...) case request.Address.IsDomain(): buffer = append(buffer, addrTypeDomain) - buffer = append(buffer, byte(len(request.Address.Domain))) - buffer = append(buffer, []byte(request.Address.Domain)...) + buffer = append(buffer, byte(len(request.Address.Domain()))) + buffer = append(buffer, []byte(request.Address.Domain())...) } paddingLength := mrand.Intn(32) + 1 diff --git a/proxy/vmess/vmessout.go b/proxy/vmess/vmessout.go index 5ed8ad54..4248801e 100644 --- a/proxy/vmess/vmessout.go +++ b/proxy/vmess/vmessout.go @@ -72,7 +72,7 @@ func startCommunicate(request *protocol.VMessRequest, dest *v2net.Destination, r input := ray.OutboundInput() output := ray.OutboundOutput() - conn, err := net.DialTCP(dest.Network(), nil, &net.TCPAddr{dest.Address().IP, int(dest.Address().Port), ""}) + conn, err := net.DialTCP(dest.Network(), nil, &net.TCPAddr{dest.Address().IP(), int(dest.Address().Port()), ""}) if err != nil { log.Error("Failed to open tcp (%s): %v", dest.String(), err) close(output)