diff --git a/app/dispatcher/impl/default.go b/app/dispatcher/impl/default.go index 287c0209..2d1564e2 100644 --- a/app/dispatcher/impl/default.go +++ b/app/dispatcher/impl/default.go @@ -66,7 +66,7 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin outbound := ray.NewRay(ctx) sniferList := proxyman.ProtocoSniffersFromContext(ctx) - if len(sniferList) == 0 { + if destination.Address.Family().IsDomain() || len(sniferList) == 0 { go d.routedDispatch(ctx, outbound, destination) } else { go func() { @@ -75,7 +75,9 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin de := <-done if de.err != nil { log.Trace(newError("failed to snif").Base(de.err)) + return } + log.Trace(newError("sniffed domain: ", de.domain)) destination.Address = net.ParseAddress(de.domain) ctx = proxy.ContextWithTarget(ctx, destination) d.routedDispatch(ctx, outbound, destination) @@ -105,6 +107,9 @@ func snifer(ctx context.Context, sniferList []proxyman.KnownProtocols, outbound return } mb := outbound.OutboundInput().Peek() + if mb.IsEmpty() { + continue + } nBytes, _ := mb.Read(payload) for _, protocol := range sniferList { var f func([]byte) (string, error) @@ -126,6 +131,13 @@ func snifer(ctx context.Context, sniferList []proxyman.KnownProtocols, outbound return } } + if nBytes == 2048 { + done <- domainOrError{ + domain: "", + err: ErrInvalidData, + } + return + } } } } diff --git a/testing/scenarios/feature_test.go b/testing/scenarios/feature_test.go index 08f57851..9ed89f24 100644 --- a/testing/scenarios/feature_test.go +++ b/testing/scenarios/feature_test.go @@ -2,11 +2,14 @@ package scenarios import ( "net" + "net/http" + "net/url" "testing" "time" xproxy "golang.org/x/net/proxy" "v2ray.com/core" + "v2ray.com/core/app/log" "v2ray.com/core/app/proxyman" "v2ray.com/core/app/router" v2net "v2ray.com/core/common/net" @@ -16,6 +19,7 @@ import ( "v2ray.com/core/proxy/blackhole" "v2ray.com/core/proxy/dokodemo" "v2ray.com/core/proxy/freedom" + v2http "v2ray.com/core/proxy/http" "v2ray.com/core/proxy/socks" "v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess/inbound" @@ -648,3 +652,93 @@ func TestUDPConnection(t *testing.T) { CloseAllServers() } + +func TestDomainSniffing(t *testing.T) { + assert := assert.On(t) + + sniffingPort := pickPort() + httpPort := pickPort() + serverConfig := &core.Config{ + Inbound: []*proxyman.InboundHandlerConfig{ + { + Tag: "snif", + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: v2net.SinglePortRange(sniffingPort), + Listen: v2net.NewIPOrDomain(v2net.LocalHostIP), + DomainOverride: []proxyman.KnownProtocols{ + proxyman.KnownProtocols_TLS, + }, + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: v2net.NewIPOrDomain(v2net.LocalHostIP), + Port: 443, + NetworkList: &v2net.NetworkList{ + Network: []v2net.Network{v2net.Network_TCP}, + }, + }), + }, + { + Tag: "http", + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: v2net.SinglePortRange(httpPort), + Listen: v2net.NewIPOrDomain(v2net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&v2http.ServerConfig{}), + }, + }, + Outbound: []*proxyman.OutboundHandlerConfig{ + { + Tag: "redir", + ProxySettings: serial.ToTypedMessage(&freedom.Config{ + DestinationOverride: &freedom.DestinationOverride{ + Server: &protocol.ServerEndpoint{ + Address: v2net.NewIPOrDomain(v2net.LocalHostIP), + Port: uint32(sniffingPort), + }, + }, + }), + }, + { + Tag: "direct", + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&router.Config{ + Rule: []*router.RoutingRule{ + { + Tag: "direct", + InboundTag: []string{"snif"}, + }, { + Tag: "redir", + InboundTag: []string{"http"}, + }, + }, + }), + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: log.LogLevel_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + } + + assert.Error(InitializeServerConfig(serverConfig)).IsNil() + + { + transport := &http.Transport{ + Proxy: func(req *http.Request) (*url.URL, error) { + return url.Parse("http://127.0.0.1:" + httpPort.String()) + }, + } + + client := &http.Client{ + Transport: transport, + } + + resp, err := client.Get("https://www.github.com/") + assert.Error(err).IsNil() + assert.Int(resp.StatusCode).Equals(200) + } + + CloseAllServers() +}