diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 0ce14408..8ed3b0e6 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -233,7 +233,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con sessionPolicy = s.policyManager.ForLevel(user.Level) if destination.Network == net.Network_UDP { // handle udp request - return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher) + return s.handleUDPPayload(ctx, sessionPolicy, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher) } ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ @@ -248,7 +248,11 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher) } -func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { +func (s *Server) handleUDPPayload(ctx context.Context, sessionPolicy policy.Session, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + defer timer.SetTimeout(0) udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { udpPayload := packet.Payload if udpPayload.UDP == nil { @@ -257,6 +261,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil { errors.LogWarningInner(ctx, err, "failed to write response") + cancel() + } else { + timer.Update() } }) defer udpServer.RemoveRay() @@ -266,47 +273,56 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade var dest *net.Destination - for { - select { - case <-ctx.Done(): - return nil - default: - mb, err := clientReader.ReadMultiBuffer() - if err != nil { - if errors.Cause(err) != io.EOF { - return errors.New("unexpected EOF").Base(err) - } + requestDone := func() error { + for { + select { + case <-ctx.Done(): return nil - } + default: + mb, err := clientReader.ReadMultiBuffer() + if err != nil { + if errors.Cause(err) != io.EOF { + return errors.New("unexpected EOF").Base(err) + } + return nil + } - mb2, b := buf.SplitFirst(mb) - if b == nil { - continue - } - destination := *b.UDP + mb2, b := buf.SplitFirst(mb) + if b == nil { + continue + } + timer.Update() + destination := *b.UDP - currentPacketCtx := ctx - if inbound.Source.IsValid() { - currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ - From: inbound.Source, - To: destination, - Status: log.AccessAccepted, - Reason: "", - Email: user.Email, - }) - } - errors.LogInfo(ctx, "tunnelling request to ", destination) + currentPacketCtx := ctx + if inbound.Source.IsValid() { + currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ + From: inbound.Source, + To: destination, + Status: log.AccessAccepted, + Reason: "", + Email: user.Email, + }) + } + errors.LogInfo(ctx, "tunnelling request to ", destination) - if !s.cone || dest == nil { - dest = &destination - } + if !s.cone || dest == nil { + dest = &destination + } - udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet - for _, payload := range mb2 { - udpServer.Dispatch(currentPacketCtx, *dest, payload) + udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet + for _, payload := range mb2 { + udpServer.Dispatch(currentPacketCtx, *dest, payload) + } } } + } + + if err := task.Run(ctx, requestDone); err != nil { + return err + } + return nil } func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,