diff --git a/common/mux/server.go b/common/mux/server.go index 70c5ed24..54120a47 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -5,7 +5,6 @@ import ( "io" "time" - "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -13,8 +12,11 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/session" + //"github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/signal/done" + "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" + //"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/pipe" @@ -64,14 +66,43 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t if dest.Address != muxCoolAddress { return s.dispatcher.DispatchLink(ctx, dest, link) } - link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link) - worker, err := NewServerWorker(ctx, s.dispatcher, link) + + // For Mux, we need to use pipe to guard against multiple sub-connections writing back responses at the same time + // sessionPolicy = h.policyManager.ForLevel(request.User.Level) + // ctx, cancel := context.WithCancel(ctx) + // timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + // ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) + opts := pipe.OptionsFromContext(ctx) + uplinkReader, uplinkWriter := pipe.New(opts...) + downlinkReader, downlinkWriter := pipe.New(opts...) + + _, err := NewServerWorker(ctx, s.dispatcher, &transport.Link{ + Reader: uplinkReader, + Writer: downlinkWriter, + }) if err != nil { return err } - select { - case <-ctx.Done(): - case <-worker.done.Wait(): + inboundLink := &transport.Link{Reader: downlinkReader, Writer: uplinkWriter} + requestDone := func() error { + //defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + if err := buf.Copy(link.Reader, inboundLink.Writer); err != nil { + return errors.New("failed to transfer request").Base(err) + } + return nil + } + responseDone := func() error { + //defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) + if err := buf.Copy(inboundLink.Reader, link.Writer); err != nil { + return err + } + return nil + } + requestDonePost := task.OnSuccess(requestDone, task.Close(inboundLink.Writer)) + if err := task.Run(ctx, requestDonePost, responseDone); err != nil { + common.Interrupt(inboundLink.Reader) + common.Interrupt(inboundLink.Writer) + return errors.New("connection ends").Base(err) } return nil }