diff --git a/common/protocol/http/headers.go b/common/protocol/http/headers.go new file mode 100644 index 00000000..c5c0644a --- /dev/null +++ b/common/protocol/http/headers.go @@ -0,0 +1,21 @@ +package http + +import ( + "net/http" + "strings" + + "v2ray.com/core/common/net" +) + +func ParseXForwardedFor(header http.Header) []net.Address { + xff := header.Get("X-Forwarded-For") + if len(xff) == 0 { + return nil + } + list := strings.Split(xff, ",") + addrs := make([]net.Address, 0, len(list)) + for _, proxy := range list { + addrs = append(addrs, net.ParseAddress(proxy)) + } + return addrs +} diff --git a/common/protocol/http/headers_test.go b/common/protocol/http/headers_test.go new file mode 100644 index 00000000..f1119a67 --- /dev/null +++ b/common/protocol/http/headers_test.go @@ -0,0 +1,20 @@ +package http_test + +import ( + "net/http" + "testing" + + . "v2ray.com/core/common/protocol/http" + . "v2ray.com/ext/assert" +) + +func TestParseXForwardedFor(t *testing.T) { + assert := With(t) + + header := http.Header{} + header.Add("X-Forwarded-For", "129.78.138.66, 129.78.64.103") + addrs := ParseXForwardedFor(header) + assert(len(addrs), Equals, 2) + assert(addrs[0].String(), Equals, "129.78.138.66") + assert(addrs[1].String(), Equals, "129.78.64.103") +} diff --git a/transport/internet/websocket/connection.go b/transport/internet/websocket/connection.go index d4e77432..3f08cf24 100644 --- a/transport/internet/websocket/connection.go +++ b/transport/internet/websocket/connection.go @@ -16,15 +16,16 @@ var ( // connection is a wrapper for net.Conn over WebSocket connection. type connection struct { - conn *websocket.Conn - reader io.Reader - + conn *websocket.Conn + reader io.Reader mergingWriter *buf.BufferedWriter + remoteAddr net.Addr } -func newConnection(conn *websocket.Conn) *connection { +func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection { return &connection{ - conn: conn, + conn: conn, + remoteAddr: remoteAddr, } } @@ -86,7 +87,7 @@ func (c *connection) LocalAddr() net.Addr { } func (c *connection) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() + return c.remoteAddr } func (c *connection) SetDeadline(t time.Time) error { diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 868e5285..6ecc8e3d 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -62,5 +62,5 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error) return nil, newError("failed to dial to (", uri, "): ", reason).Base(err) } - return newConnection(conn), nil + return newConnection(conn, conn.RemoteAddr()), nil } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 0226be78..4f4c3e46 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -12,6 +12,7 @@ import ( "v2ray.com/core/app/log" "v2ray.com/core/common" "v2ray.com/core/common/net" + http_proto "v2ray.com/core/common/protocol/http" "v2ray.com/core/transport/internet" v2tls "v2ray.com/core/transport/internet/tls" ) @@ -38,7 +39,13 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req return } - h.ln.addConn(h.ln.ctx, newConnection(conn)) + forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) + remoteAddr := conn.RemoteAddr() + if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().Either(net.AddressFamilyIPv4, net.AddressFamilyIPv6) { + remoteAddr.(*net.TCPAddr).IP = forwardedAddrs[0].IP() + } + + h.ln.addConn(h.ln.ctx, newConnection(conn, remoteAddr)) } type Listener struct { diff --git a/transport/internet/websocket/ws_test.go b/transport/internet/websocket/ws_test.go index ac97fbf8..678c18b7 100644 --- a/transport/internet/websocket/ws_test.go +++ b/transport/internet/websocket/ws_test.go @@ -59,15 +59,46 @@ func Test_listenWSAndDial(t *testing.T) { assert(err, IsNil) assert(string(b[:n]), Equals, "Response") assert(conn.Close(), IsNil) - <-time.After(time.Second * 15) - conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146)) + + assert(listen.Close(), IsNil) +} + +func TestDialWithRemoteAddr(t *testing.T) { + assert := With(t) + listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{ + Path: "ws", + }), net.DomainAddress("localhost"), 13146, func(ctx context.Context, conn internet.Connection) bool { + go func(c internet.Connection) { + defer c.Close() + + assert(c.RemoteAddr().String(), HasPrefix, "1.1.1.1") + + var b [1024]byte + n, err := c.Read(b[:]) + //assert(err, IsNil) + if err != nil { + return + } + assert(bytes.HasPrefix(b[:n], []byte("Test connection")), IsTrue) + + _, err = c.Write([]byte("Response")) + assert(err, IsNil) + }(conn) + return true + }) assert(err, IsNil) - _, err = conn.Write([]byte("Test connection 3")) + + ctx := internet.ContextWithTransportSettings(context.Background(), &Config{Path: "ws", Header: []*Header{{Key: "X-Forwarded-For", Value: "1.1.1.1"}}}) + conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146)) + assert(err, IsNil) - n, err = conn.Read(b[:]) + _, err = conn.Write([]byte("Test connection 1")) + assert(err, IsNil) + + var b [1024]byte + n, err := conn.Read(b[:]) assert(err, IsNil) assert(string(b[:n]), Equals, "Response") - assert(conn.Close(), IsNil) assert(listen.Close(), IsNil) }