Refine Trojan packet reader & writer (#142)

pull/152/head
maskedeken 2021-01-08 11:55:25 +08:00 committed by GitHub
parent 161e18299c
commit d5aeb6c545
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 50 deletions

View File

@ -146,26 +146,6 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
return nil return nil
} }
// WriteMultiBufferWithMetadata writes udp packet with destination specified
func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
for {
mb2, b := buf.SplitFirst(mb)
mb = mb2
if b == nil {
break
}
source := &dest
if b.UDP != nil {
source = b.UDP
}
if _, err := w.writePacket(b.Bytes(), *source); err != nil {
buf.ReleaseMulti(mb)
return err
}
}
return nil
}
func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) {
buffer := buf.StackNew() buffer := buf.StackNew()
defer buffer.Release() defer buffer.Release()
@ -259,12 +239,6 @@ func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
return buf.MultiBuffer{b}, err return buf.MultiBuffer{b}, err
} }
// PacketPayload combines udp payload and destination
type PacketPayload struct {
Target net.Destination
Buffer buf.MultiBuffer
}
// PacketReader is UDP Connection Reader Wrapper for trojan protocol // PacketReader is UDP Connection Reader Wrapper for trojan protocol
type PacketReader struct { type PacketReader struct {
io.Reader io.Reader
@ -272,15 +246,6 @@ type PacketReader struct {
// ReadMultiBuffer implements buf.Reader // ReadMultiBuffer implements buf.Reader
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
p, err := r.ReadMultiBufferWithMetadata()
if p != nil {
return p.Buffer, err
}
return nil, err
}
// ReadMultiBufferWithMetadata reads udp packet with destination
func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
addr, port, err := addrParser.ReadAddressPort(nil, r) addr, port, err := addrParser.ReadAddressPort(nil, r)
if err != nil { if err != nil {
return nil, newError("failed to read address and port").Base(err) return nil, newError("failed to read address and port").Base(err)
@ -321,7 +286,7 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
remain -= int(n) remain -= int(n)
} }
return &PacketPayload{Target: dest, Buffer: mb}, nil return mb, nil
} }
func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn *xtls.Conn, rawConn syscall.RawConn, counter stats.Counter, sctx context.Context) error { func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn *xtls.Conn, rawConn syscall.RawConn, counter stats.Counter, sctx context.Context) error {

View File

@ -71,21 +71,22 @@ func TestUDPRequest(t *testing.T) {
common.Must(connReader.ParseHeader()) common.Must(connReader.ParseHeader())
packetReader := &PacketReader{Reader: connReader} packetReader := &PacketReader{Reader: connReader}
p, err := packetReader.ReadMultiBufferWithMetadata() mb, err := packetReader.ReadMultiBuffer()
common.Must(err) common.Must(err)
if p.Buffer.IsEmpty() { if mb.IsEmpty() {
t.Error("no request data") t.Error("no request data")
} }
if r := cmp.Diff(p.Target, destination); r != "" { mb2, b := buf.SplitFirst(mb)
defer buf.ReleaseMulti(mb2)
dest := *b.UDP
if r := cmp.Diff(dest, destination); r != "" {
t.Error("destination: ", r) t.Error("destination: ", r)
} }
mb, decoded := buf.SplitFirst(p.Buffer) if r := cmp.Diff(b.Bytes(), payload); r != "" {
buf.ReleaseMulti(mb)
if r := cmp.Diff(decoded.Bytes(), payload); r != "" {
t.Error("data: ", r) t.Error("data: ", r)
} }
} }

View File

@ -250,7 +250,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error { func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
common.Must(clientWriter.WriteMultiBufferWithMetadata(buf.MultiBuffer{packet.Payload}, packet.Source)) udpPayload := packet.Payload
udpPayload.UDP = &packet.Source
common.Must(clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}))
}) })
inbound := session.InboundFromContext(ctx) inbound := session.InboundFromContext(ctx)
@ -263,7 +265,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
default: default:
p, err := clientReader.ReadMultiBufferWithMetadata() mb, err := clientReader.ReadMultiBuffer()
if err != nil { if err != nil {
if errors.Cause(err) != io.EOF { if errors.Cause(err) != io.EOF {
return newError("unexpected EOF").Base(err) return newError("unexpected EOF").Base(err)
@ -271,21 +273,24 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
return nil return nil
} }
mb2, b := buf.SplitFirst(mb)
destination := *b.UDP
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source, From: inbound.Source,
To: p.Target, To: destination,
Status: log.AccessAccepted, Status: log.AccessAccepted,
Reason: "", Reason: "",
Email: user.Email, Email: user.Email,
}) })
newError("tunnelling request to ", p.Target).WriteToLog(session.ExportIDToError(ctx)) newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
if !buf.Cone || dest == nil { if !buf.Cone || dest == nil {
dest = &p.Target dest = &destination
} }
for _, b := range p.Buffer { udpServer.Dispatch(ctx, *dest, b) // first packet
udpServer.Dispatch(ctx, *dest, b) for _, payload := range mb2 {
udpServer.Dispatch(ctx, *dest, payload)
} }
} }
} }