diff --git a/common/net/destination.go b/common/net/destination.go index 99f243d4..17b301ca 100644 --- a/common/net/destination.go +++ b/common/net/destination.go @@ -2,6 +2,7 @@ package net import ( "net" + "strings" ) // Destination represents a network destination including address and protocol (tcp / udp). @@ -26,6 +27,37 @@ func DestinationFromAddr(addr net.Addr) Destination { } } +// ParseDestination converts a destination from its string presentation. +func ParseDestination(dest string) (Destination, error) { + d := Destination{ + Address: AnyIP, + Port: Port(0), + } + if strings.HasPrefix(dest, "tcp:") { + d.Network = Network_TCP + dest = dest[4:] + } else if strings.HasPrefix(dest, "udp:") { + d.Network = Network_UDP + dest = dest[4:] + } + + hstr, pstr, err := SplitHostPort(dest) + if err != nil { + return d, err + } + if len(hstr) > 0 { + d.Address = ParseAddress(hstr) + } + if len(pstr) > 0 { + port, err := PortFromString(pstr) + if err != nil { + return d, err + } + d.Port = port + } + return d, nil +} + // TCPDestination creates a TCP destination with given address func TCPDestination(address Address, port Port) Destination { return Destination{ diff --git a/common/net/destination_test.go b/common/net/destination_test.go index b68e9ccb..8c9cd9ad 100644 --- a/common/net/destination_test.go +++ b/common/net/destination_test.go @@ -25,3 +25,54 @@ func TestUDPDestination(t *testing.T) { assert(dest, IsUDP) assert(dest.String(), Equals, "udp:[2001:4860:4860::8888]:53") } + +func TestDestinationParse(t *testing.T) { + assert := With(t) + + cases := []struct { + Input string + Output Destination + Error bool + }{ + { + Input: "tcp:127.0.0.1:80", + Output: TCPDestination(LocalHostIP, Port(80)), + }, + { + Input: "udp:8.8.8.8:53", + Output: UDPDestination(IPAddress([]byte{8, 8, 8, 8}), Port(53)), + }, + { + Input: "8.8.8.8:53", + Output: Destination{ + Address: IPAddress([]byte{8, 8, 8, 8}), + Port: Port(53), + }, + }, + { + Input: ":53", + Output: Destination{ + Address: AnyIP, + Port: Port(53), + }, + }, + { + Input: "8.8.8.8", + Error: true, + }, + { + Input: "8.8.8.8:http", + Error: true, + }, + } + + for _, testcase := range cases { + d, err := ParseDestination(testcase.Input) + if !testcase.Error { + assert(err, IsNil) + assert(d, Equals, testcase.Output) + } else { + assert(err, IsNotNil) + } + } +} diff --git a/transport/internet/http/hub.go b/transport/internet/http/hub.go index c5bd8d79..89371fee 100644 --- a/transport/internet/http/hub.go +++ b/transport/internet/http/hub.go @@ -63,13 +63,25 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request) if f, ok := writer.(http.Flusher); ok { f.Flush() } + + remoteAddr := l.Addr() + dest, err := net.ParseDestination(request.RemoteAddr) + if err != nil { + newError("failed to parse request remote addr: ", request.RemoteAddr).Base(err).WriteToLog() + } else { + remoteAddr = &net.TCPAddr{ + IP: dest.Address.IP(), + Port: int(dest.Port), + } + } + done := signal.NewDone() conn := net.NewConnection( net.ConnectionOutput(request.Body), net.ConnectionInput(flushWriter{w: writer, d: done}), net.ConnectionOnClose(common.NewChainedClosable(done, request.Body)), net.ConnectionLocalAddr(l.Addr()), - net.ConnectionRemoteAddr(l.Addr()), + net.ConnectionRemoteAddr(remoteAddr), ) l.handler(conn) <-done.Wait()