diff --git a/app/reverse/portal.go b/app/reverse/portal.go index 7e3f2caf..c42b2825 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -72,7 +72,14 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err } if isDomain(ob.Target, p.domain) { - muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{}) + opts := pipe.OptionsFromContext(ctx) + uplinkReader, uplinkWriter := pipe.New(opts...) + downlinkReader, downlinkWriter := pipe.New(opts...) + + muxClient, err := mux.NewClientWorker(transport.Link{ + Reader: uplinkReader, + Writer: downlinkWriter, + }, mux.ClientStrategy{}) if err != nil { return errors.New("failed to create mux client worker").Base(err).AtWarning() } @@ -84,11 +91,24 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err p.picker.AddWorker(worker) - if _, ok := link.Reader.(*pipe.Reader); !ok { - select { - case <-ctx.Done(): - case <-muxClient.WaitClosed(): + inboundLink := &transport.Link{Reader: downlinkReader, Writer: uplinkWriter} + requestDone := func() error { + if err := buf.Copy(link.Reader, inboundLink.Writer); err != nil { + return errors.New("failed to transfer request").Base(err) } + return nil + } + responseDone := func() error { + if err := buf.Copy(inboundLink.Reader, link.Writer); err != nil { + return err + } + return nil + } + requestDonePost := task.OnSuccess(requestDone, task.Close(inboundLink.Writer)) + if err := task.Run(ctx, requestDonePost, responseDone); err != nil { + common.Interrupt(inboundLink.Reader) + common.Interrupt(inboundLink.Writer) + return errors.New("connection ends").Base(err) } return nil }