diff --git a/app/reverse/portal.go b/app/reverse/portal.go index 7e3f2caf..c42b2825 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -72,7 +72,14 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err } if isDomain(ob.Target, p.domain) { - muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{}) + opts := pipe.OptionsFromContext(ctx) + uplinkReader, uplinkWriter := pipe.New(opts...) + downlinkReader, downlinkWriter := pipe.New(opts...) + + muxClient, err := mux.NewClientWorker(transport.Link{ + Reader: uplinkReader, + Writer: downlinkWriter, + }, mux.ClientStrategy{}) if err != nil { return errors.New("failed to create mux client worker").Base(err).AtWarning() } @@ -84,11 +91,24 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err p.picker.AddWorker(worker) - if _, ok := link.Reader.(*pipe.Reader); !ok { - select { - case <-ctx.Done(): - case <-muxClient.WaitClosed(): + inboundLink := &transport.Link{Reader: downlinkReader, Writer: uplinkWriter} + requestDone := func() error { + if err := buf.Copy(link.Reader, inboundLink.Writer); err != nil { + return errors.New("failed to transfer request").Base(err) } + return nil + } + responseDone := func() error { + if err := buf.Copy(inboundLink.Reader, link.Writer); err != nil { + return err + } + return nil + } + requestDonePost := task.OnSuccess(requestDone, task.Close(inboundLink.Writer)) + if err := task.Run(ctx, requestDonePost, responseDone); err != nil { + common.Interrupt(inboundLink.Reader) + common.Interrupt(inboundLink.Writer) + return errors.New("connection ends").Base(err) } return nil } diff --git a/common/mux/server.go b/common/mux/server.go index 70c5ed24..54120a47 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -5,7 +5,6 @@ import ( "io" "time" - "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -13,8 +12,11 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/session" + //"github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal/done" + "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" + //"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/pipe" @@ -64,14 +66,43 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t if dest.Address != muxCoolAddress { return s.dispatcher.DispatchLink(ctx, dest, link) } - link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link) - worker, err := NewServerWorker(ctx, s.dispatcher, link) + + // For Mux, we need to use pipe to guard against multiple sub-connections writing back responses at the same time + // sessionPolicy = h.policyManager.ForLevel(request.User.Level) + // ctx, cancel := context.WithCancel(ctx) + // timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + // ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) + opts := pipe.OptionsFromContext(ctx) + uplinkReader, uplinkWriter := pipe.New(opts...) + downlinkReader, downlinkWriter := pipe.New(opts...) + + _, err := NewServerWorker(ctx, s.dispatcher, &transport.Link{ + Reader: uplinkReader, + Writer: downlinkWriter, + }) if err != nil { return err } - select { - case <-ctx.Done(): - case <-worker.done.Wait(): + inboundLink := &transport.Link{Reader: downlinkReader, Writer: uplinkWriter} + requestDone := func() error { + //defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + if err := buf.Copy(link.Reader, inboundLink.Writer); err != nil { + return errors.New("failed to transfer request").Base(err) + } + return nil + } + responseDone := func() error { + //defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) + if err := buf.Copy(inboundLink.Reader, link.Writer); err != nil { + return err + } + return nil + } + requestDonePost := task.OnSuccess(requestDone, task.Close(inboundLink.Writer)) + if err := task.Run(ctx, requestDonePost, responseDone); err != nil { + common.Interrupt(inboundLink.Reader) + common.Interrupt(inboundLink.Writer) + return errors.New("connection ends").Base(err) } return nil } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 7975551b..eae12a7a 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -14,8 +14,6 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal" - "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/common/uuid" "github.com/xtls/xray-core/core" feature_inbound "github.com/xtls/xray-core/features/inbound" @@ -23,6 +21,7 @@ import ( "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/proxy/vmess" "github.com/xtls/xray-core/proxy/vmess/encoding" + "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet/stat" ) @@ -184,44 +183,6 @@ func (h *Handler) RemoveUser(ctx context.Context, email string) error { return nil } -func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output *buf.BufferedWriter) error { - session.EncodeResponseHeader(response, output) - - bodyWriter, err := session.EncodeResponseBody(request, output) - if err != nil { - return errors.New("failed to start decoding response").Base(err) - } - { - // Optimize for small response packet - data, err := input.ReadMultiBuffer() - if err != nil { - return err - } - - if err := bodyWriter.WriteMultiBuffer(data); err != nil { - return err - } - } - - if err := output.SetBuffered(false); err != nil { - return err - } - - if err := buf.Copy(input, bodyWriter, buf.UpdateActivity(timer)); err != nil { - return err - } - - account := request.User.Account.(*vmess.MemoryAccount) - - if request.Option.Has(protocol.RequestOptionChunkStream) && !account.NoTerminationSignal { - if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil { - return err - } - } - - return nil -} - // Process implements proxy.Inbound.Process(). func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { sessionPolicy := h.policyManager.ForLevel(0) @@ -275,49 +236,28 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s inbound.CanSpliceCopy = 3 inbound.User = request.User - sessionPolicy = h.policyManager.ForLevel(request.User.Level) - - ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) - - ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) - link, err := dispatcher.Dispatch(ctx, request.Destination()) + bodyReader, err := svrSession.DecodeRequestBody(request, reader) if err != nil { - return errors.New("failed to dispatch request to ", request.Destination()).Base(err) + return errors.New("failed to start decoding").Base(err) } - requestDone := func() error { - defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) - - bodyReader, err := svrSession.DecodeRequestBody(request, reader) - if err != nil { - return errors.New("failed to start decoding").Base(err) - } - if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil { - return errors.New("failed to transfer request").Base(err) - } - return nil + writer := buf.NewBufferedWriter(buf.NewWriter(connection)) + response := &protocol.ResponseHeader{ + Command: h.generateCommand(ctx, request), } - - responseDone := func() error { - defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) - - writer := buf.NewBufferedWriter(buf.NewWriter(connection)) - defer writer.Flush() - - response := &protocol.ResponseHeader{ - Command: h.generateCommand(ctx, request), - } - return transferResponse(timer, svrSession, request, response, link.Reader, writer) + svrSession.EncodeResponseHeader(response, writer) + bodyWriter, err := svrSession.EncodeResponseBody(request, writer) + if err != nil { + return errors.New("failed to start decoding response").Base(err) } + writer.SetFlushNext() - requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) - if err := task.Run(ctx, requestDonePost, responseDone); err != nil { - common.Interrupt(link.Reader) - common.Interrupt(link.Writer) - return errors.New("connection ends").Base(err) + if err := dispatcher.DispatchLink(ctx, request.Destination(), &transport.Link{ + Reader: bodyReader, + Writer: bodyWriter}, + ); err != nil { + return errors.New("failed to dispatch request").Base(err) } - return nil } diff --git a/testing/scenarios/vless_test.go b/testing/scenarios/vless_test.go index b699f497..95e1ad0b 100644 --- a/testing/scenarios/vless_test.go +++ b/testing/scenarios/vless_test.go @@ -121,6 +121,212 @@ func TestVless(t *testing.T) { } } +func TestVlessMuxTcp(t *testing.T) { + tcpServer := tcp.Server{ + MsgProcessor: xor, + } + dest, err := tcpServer.Start() + common.Must(err) + defer tcpServer.Close() + + userID := protocol.NewID(uuid.New()) + serverPort := tcp.PickPort() + serverConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(serverPort)}}, + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&inbound.Config{ + Clients: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + }), + }, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + clientPort := tcp.PickPort() + clientConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(clientPort)}}, + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: net.NewIPOrDomain(dest.Address), + Port: uint32(dest.Port), + Networks: []net.Network{net.Network_TCP}, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + SenderSettings: serial.ToTypedMessage(&proxyman.SenderConfig{ + MultiplexSettings: &proxyman.MultiplexingConfig{ + Enabled: true, + Concurrency: 4, + }, + }), + ProxySettings: serial.ToTypedMessage(&outbound.Config{ + Vnext: &protocol.ServerEndpoint{ + Address: net.NewIPOrDomain(net.LocalHostIP), + Port: uint32(serverPort), + User: &protocol.User{ + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + }), + }, + }, + }), + }, + }, + } + + servers, err := InitializeServerConfigs(serverConfig, clientConfig) + common.Must(err) + defer CloseAllServers(servers) + + for range "abcd" { + var errg errgroup.Group + for range 3 { + errg.Go(testTCPConn(clientPort, 10240, time.Second*20)) + } + if err := errg.Wait(); err != nil { + t.Fatal(err) + } + time.Sleep(time.Second) + } +} + +func TestVlessEncMuxTcp(t *testing.T) { + tcpServer := tcp.Server{ + MsgProcessor: xor, + } + dest, err := tcpServer.Start() + common.Must(err) + defer tcpServer.Close() + + userID := protocol.NewID(uuid.New()) + serverPort := tcp.PickPort() + serverConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(serverPort)}}, + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&inbound.Config{ + Clients: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + }), + }, + }, + SecondsFrom: 600, //mlkem768x25519plus.native.600s. + Decryption: "Gzh5Aa3Ibo3343XFC7V2a8ucOpFeGjOL6jMlBZAfjqyty2rdRms8xccBAm68imYw2q96gg2dcueeL2r7n_2YzQ", + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + clientPort := tcp.PickPort() + clientConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(clientPort)}}, + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: net.NewIPOrDomain(dest.Address), + Port: uint32(dest.Port), + Networks: []net.Network{net.Network_TCP}, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + SenderSettings: serial.ToTypedMessage(&proxyman.SenderConfig{ + MultiplexSettings: &proxyman.MultiplexingConfig{ + Enabled: true, + Concurrency: 4, + }, + }), + ProxySettings: serial.ToTypedMessage(&outbound.Config{ + Vnext: &protocol.ServerEndpoint{ + Address: net.NewIPOrDomain(net.LocalHostIP), + Port: uint32(serverPort), + User: &protocol.User{ + Account: serial.ToTypedMessage(&vless.Account{ + Id: userID.String(), + Seconds: 1, //mlkem768x25519plus.native.0rtt. + Encryption: "ExaMB4tIHpFikMeZwAJ8_8hxpZNi3gY13Ft455yC04xiCWgWUwMvKUwDQVm8zLcE8EKnjVlhRDmkTzMzvTMZyYlswCuqx0YK9kVNNFcrQJWD8JpAmTN8fffApIoWitDEAUTEp9S_Ehxo-9a2evRyKqJcQ6WmPiiyGbZrnNAfLKhdRsA15rZt6eKMVQExtDpucfaFc2E4-GtKzKd7P0I6bXccC1q4gqyZcXiEfOmmWBTPMTkNPEUdnQVsPiSWgJxslQZ5pYlPE7GQE7qoxYBItDMhkHZ4l0YwsvgZ1EQ2yTEn9DOxbyMihLk4kSAtg1IrW7tCTNkhyVsUY3SeyReB2sfN2AU-TXmVGUJMTKJ1jfywu8JIb9lG14HB1Rku6nVNcIMTzyshvsi_8AQFCSOcDdQ7ZpBxKxW7N1tKXBI0shq7vWdufjpYCjAVh-k_QgonVOwadYt-wPMxDntbWzEf_yC9eFQ6cBGd5smWNeSQZwAvqXw_WVPD56EVlaQ5HpsOkqBdy1Enr1NnH7WdgNsfk6RSQhRgW1dF9XBUKylpqsvOXkq3I0fLuuJFfuEZu4MeNvdgI2mbM_UxK8AzlRwkm7Eb1WQfm-S05HJefdZzu8kHYamggwtNQum_NtODzRgw3uWbjYbEBIY0j9IMhyGynOYQHHmrR2kT-dh08GwVD7BfsJRvFYgy2ZI8a3xGgHyi6MKKE8g7krEd-ne_4ddSaysgctaiiLwI4NVRbYJIT8XEbmKTIwoZx4R7m7AffYJo2NlfEPREg8stBcY5dAGXeSwD0pxs-jCJOeifQYq7Elq216SrwCmayLg3XJcpxutOmkhai6hRO6eBP6uy9XlLXyMt3TW6isx_rRt1hXCezkl_8hPEcqI9tPE0ZYVQ-eMh2_e35gQyPUw02aequ4ojaHV03QaSMquqF8RXG7k1gDed9vqex3aFaSN6UUNkebLKrqAiPmq0fccQ3qdbAxLGZ0ZFF5mIwEiFoTM6V4yPgntkRYtxcCKK-5YkPfsIunrM3EsWDCovp_Ahdfs-aqQLqzk1wVKTLQaQI5ApBlmGB3EauNdHFJBoeGZOF9e7QbGujhGRGMpS1fFtI2SqlcXINZU7YvR2JMfBrvBYZ9whXawM_Rg31IJR1raMGAEm6hNpa7SBD0cprIZxG6HKUQFMGHVlVohjwpWE5AGIc5Rc8Va2x8e3zFTMTUIwCdMz1XlNaqBMldJx01JQLwgSsnfGGlEJ_jYujvYNo0EBk4yev1Ap6nO-zSU-WtimlhEP0-cb22Q6e4wCEnWfO-lABJsrhwhrbloM51k5QVIefNyIvDWBszpRsreidUZVU4TOH2EoltYslWdPkcckfCplFLyvGKBItoAPRTOKRCjOsqlmj9OvpbDCzedZUmjLNfoLSwsPC7Nk2FpIkVUG6WxCE2YiU9LFrZIgWRKwUluM_at9w7wowRkujXEAQiJKtuUWQCxGyVbJtufLmQI6_yafmwgLoSlyE0cL-_Rf4nBCBjJnmyBDRvAoA-W08vw53uMt3RnFVwKFqo3PonmYAETv5rrMjh3L3K16QS-2EgL_R7WAFd0", + }), + }, + }, + }), + }, + }, + } + + servers, err := InitializeServerConfigs(serverConfig, clientConfig) + common.Must(err) + defer CloseAllServers(servers) + + for range "abcd" { + var errg errgroup.Group + for range 3 { + errg.Go(testTCPConn(clientPort, 10240, time.Second*20)) + } + if err := errg.Wait(); err != nil { + t.Fatal(err) + } + time.Sleep(time.Second) + } +} + func TestVlessTls(t *testing.T) { tcpServer := tcp.Server{ MsgProcessor: xor, diff --git a/testing/scenarios/vmess_test.go b/testing/scenarios/vmess_test.go index 402cf940..d290e1fb 100644 --- a/testing/scenarios/vmess_test.go +++ b/testing/scenarios/vmess_test.go @@ -1213,7 +1213,7 @@ func TestVMessGCMLengthAuthPlusNoTerminationSignal(t *testing.T) { { Account: serial.ToTypedMessage(&vmess.Account{ Id: userID.String(), - TestsEnabled: "AuthenticatedLength|NoTerminationSignal", + TestsEnabled: "AuthenticatedLength|", }), }, },