diff --git a/app/reverse/portal.go b/app/reverse/portal.go index cec20d27..5104238a 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -84,11 +84,20 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err p.picker.AddWorker(worker) if _, ok := link.Reader.(*pipe.Reader); !ok { - <-ctx.Done() // from DispatchLink() + select { + case <-ctx.Done(): + case <-muxClient.WaitClosed(): + } } return nil } + if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address { + link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} + link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} + } + + return p.client.Dispatch(ctx, link) } @@ -105,6 +114,7 @@ func (o *Outbound) Dispatch(ctx context.Context, link *transport.Link) { if err := o.portal.HandleConnection(ctx, link); err != nil { errors.LogInfoInner(ctx, err, "failed to process reverse connection") common.Interrupt(link.Writer) + common.Interrupt(link.Reader) } } diff --git a/common/mux/client.go b/common/mux/client.go index e94fd3ad..dddb6371 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -215,6 +215,10 @@ func (m *ClientWorker) Closed() bool { return m.done.Done() } +func (m *ClientWorker) WaitClosed() <-chan struct{} { + return m.done.Wait() +} + func (m *ClientWorker) GetTimer() *time.Ticker { return m.timer }