diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index ad3bcc32..650496f0 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -15,6 +15,7 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/errors" "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" "v2ray.com/core/proxy" "v2ray.com/core/transport/ray" ) @@ -173,37 +174,37 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool return true } -func drain(reader *Reader) error { - buf.Copy(reader, buf.Discard) +func drain(reader io.Reader) error { + buf.Copy(NewStreamReader(reader), buf.Discard) return nil } -func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *Reader) error { +func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader io.Reader) error { if meta.Option.Has(OptionData) { return drain(reader) } return nil } -func (m *Client) handleStatusNew(meta *FrameMetadata, reader *Reader) error { +func (m *Client) handleStatusNew(meta *FrameMetadata, reader io.Reader) error { if meta.Option.Has(OptionData) { return drain(reader) } return nil } -func (m *Client) handleStatusKeep(meta *FrameMetadata, reader *Reader) error { +func (m *Client) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error { if !meta.Option.Has(OptionData) { return nil } if s, found := m.sessionManager.Get(meta.SessionID); found { - return buf.Copy(reader, s.output, buf.IgnoreWriterError()) + return buf.Copy(s.NewReader(reader), s.output, buf.IgnoreWriterError()) } return drain(reader) } -func (m *Client) handleStatusEnd(meta *FrameMetadata, reader *Reader) error { +func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error { if s, found := m.sessionManager.Get(meta.SessionID); found { s.CloseDownlink() s.output.Close() @@ -217,9 +218,11 @@ func (m *Client) handleStatusEnd(meta *FrameMetadata, reader *Reader) error { func (m *Client) fetchOutput() { defer m.cancel() - reader := NewReader(m.inboundRay.InboundOutput()) + reader := buf.ToBytesReader(m.inboundRay.InboundOutput()) + metaReader := NewMetadataReader(reader) + for { - meta, err := reader.ReadMetadata() + meta, err := metaReader.Read() if err != nil { if errors.Cause(err) != io.EOF { log.Trace(newError("failed to read metadata").Base(err)) @@ -289,7 +292,7 @@ type ServerWorker struct { } func handle(ctx context.Context, s *Session, output buf.Writer) { - writer := NewResponseWriter(s.ID, output) + writer := NewResponseWriter(s.ID, output, s.transferType) if err := buf.Copy(s.input, writer); err != nil { log.Trace(newError("session ", s.ID, " ends: ").Base(err)) } @@ -297,14 +300,14 @@ func handle(ctx context.Context, s *Session, output buf.Writer) { s.CloseDownlink() } -func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *Reader) error { +func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader io.Reader) error { if meta.Option.Has(OptionData) { return drain(reader) } return nil } -func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *Reader) error { +func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader io.Reader) error { log.Trace(newError("received request for ", meta.Target)) inboundRay, err := w.dispatcher.Dispatch(ctx, meta.Target) if err != nil { @@ -314,30 +317,34 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, return newError("failed to dispatch request.").Base(err) } s := &Session{ - input: inboundRay.InboundOutput(), - output: inboundRay.InboundInput(), - parent: w.sessionManager, - ID: meta.SessionID, + input: inboundRay.InboundOutput(), + output: inboundRay.InboundInput(), + parent: w.sessionManager, + ID: meta.SessionID, + transferType: protocol.TransferTypeStream, + } + if meta.Target.Network == net.Network_UDP { + s.transferType = protocol.TransferTypePacket } w.sessionManager.Add(s) go handle(ctx, s, w.outboundRay.OutboundOutput()) if meta.Option.Has(OptionData) { - return buf.Copy(reader, s.output, buf.IgnoreWriterError()) + return buf.Copy(s.NewReader(reader), s.output, buf.IgnoreWriterError()) } return nil } -func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *Reader) error { +func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error { if !meta.Option.Has(OptionData) { return nil } if s, found := w.sessionManager.Get(meta.SessionID); found { - return buf.Copy(reader, s.output, buf.IgnoreWriterError()) + return buf.Copy(s.NewReader(reader), s.output, buf.IgnoreWriterError()) } return drain(reader) } -func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *Reader) error { +func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error { if s, found := w.sessionManager.Get(meta.SessionID); found { s.CloseUplink() s.output.Close() @@ -348,8 +355,9 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *Reader) erro return nil } -func (w *ServerWorker) handleFrame(ctx context.Context, reader *Reader) error { - meta, err := reader.ReadMetadata() +func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error { + metaReader := NewMetadataReader(reader) + meta, err := metaReader.Read() if err != nil { return newError("failed to read metadata").Base(err) } @@ -375,7 +383,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *Reader) error { func (w *ServerWorker) run(ctx context.Context) { input := w.outboundRay.OutboundInput() - reader := NewReader(input) + reader := buf.ToBytesReader(input) defer w.sessionManager.Close() diff --git a/app/proxyman/mux/mux_test.go b/app/proxyman/mux/mux_test.go index b4501c64..f7977cf3 100644 --- a/app/proxyman/mux/mux_test.go +++ b/app/proxyman/mux/mux_test.go @@ -2,28 +2,45 @@ package mux_test import ( "context" + "io" "testing" . "v2ray.com/core/app/proxyman/mux" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" "v2ray.com/core/testing/assert" "v2ray.com/core/transport/ray" ) +func readAll(reader buf.Reader) (buf.MultiBuffer, error) { + mb := buf.NewMultiBuffer() + for { + b, err := reader.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + mb.AppendMulti(b) + } + return mb, nil +} + func TestReaderWriter(t *testing.T) { assert := assert.On(t) stream := ray.NewStream(context.Background()) dest := net.TCPDestination(net.DomainAddress("v2ray.com"), 80) - writer := NewWriter(1, dest, stream) + writer := NewWriter(1, dest, stream, protocol.TransferTypeStream) dest2 := net.TCPDestination(net.LocalHostIP, 443) - writer2 := NewWriter(2, dest2, stream) + writer2 := NewWriter(2, dest2, stream, protocol.TransferTypeStream) dest3 := net.TCPDestination(net.LocalHostIPv6, 18374) - writer3 := NewWriter(3, dest3, stream) + writer3 := NewWriter(3, dest3, stream, protocol.TransferTypeStream) writePayload := func(writer *Writer, payload ...byte) error { b := buf.New() @@ -43,73 +60,76 @@ func TestReaderWriter(t *testing.T) { assert.Error(writePayload(writer2, 'y')).IsNil() writer2.Close() - reader := NewReader(stream) - meta, err := reader.ReadMetadata() + bytesReader := buf.ToBytesReader(stream) + metaReader := NewMetadataReader(bytesReader) + streamReader := NewStreamReader(bytesReader) + + meta, err := metaReader.Read() assert.Error(err).IsNil() assert.Uint16(meta.SessionID).Equals(1) assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew)) assert.Destination(meta.Target).Equals(dest) assert.Byte(byte(meta.Option)).Equals(byte(OptionData)) - data, err := reader.Read() + data, err := readAll(streamReader) assert.Error(err).IsNil() assert.Int(len(data)).Equals(1) assert.String(data[0].String()).Equals("abcd") - meta, err = reader.ReadMetadata() + meta, err = metaReader.Read() assert.Error(err).IsNil() assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew)) assert.Uint16(meta.SessionID).Equals(2) assert.Byte(byte(meta.Option)).Equals(0) assert.Destination(meta.Target).Equals(dest2) - meta, err = reader.ReadMetadata() + meta, err = metaReader.Read() assert.Error(err).IsNil() assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep)) assert.Uint16(meta.SessionID).Equals(1) assert.Byte(byte(meta.Option)).Equals(1) - data, err = reader.Read() + data, err = readAll(streamReader) assert.Error(err).IsNil() assert.Int(len(data)).Equals(1) assert.String(data[0].String()).Equals("efgh") - meta, err = reader.ReadMetadata() + meta, err = metaReader.Read() assert.Error(err).IsNil() assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew)) assert.Uint16(meta.SessionID).Equals(3) assert.Byte(byte(meta.Option)).Equals(1) assert.Destination(meta.Target).Equals(dest3) - data, err = reader.Read() + data, err = readAll(streamReader) assert.Error(err).IsNil() assert.Int(len(data)).Equals(1) assert.String(data[0].String()).Equals("x") - meta, err = reader.ReadMetadata() + meta, err = metaReader.Read() assert.Error(err).IsNil() assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd)) assert.Uint16(meta.SessionID).Equals(1) assert.Byte(byte(meta.Option)).Equals(0) - meta, err = reader.ReadMetadata() + meta, err = metaReader.Read() assert.Error(err).IsNil() assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd)) assert.Uint16(meta.SessionID).Equals(3) assert.Byte(byte(meta.Option)).Equals(0) - meta, err = reader.ReadMetadata() + meta, err = metaReader.Read() assert.Error(err).IsNil() assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep)) assert.Uint16(meta.SessionID).Equals(2) assert.Byte(byte(meta.Option)).Equals(1) - data, err = reader.Read() + data, err = readAll(streamReader) assert.Error(err).IsNil() assert.Int(len(data)).Equals(1) assert.String(data[0].String()).Equals("y") - meta, err = reader.ReadMetadata() + meta, err = metaReader.Read() assert.Error(err).IsNil() assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd)) assert.Uint16(meta.SessionID).Equals(2) @@ -117,6 +137,6 @@ func TestReaderWriter(t *testing.T) { stream.Close() - meta, err = reader.ReadMetadata() + meta, err = metaReader.Read() assert.Error(err).IsNotNil() } diff --git a/app/proxyman/mux/reader.go b/app/proxyman/mux/reader.go index 278a05c4..d309f8df 100644 --- a/app/proxyman/mux/reader.go +++ b/app/proxyman/mux/reader.go @@ -7,57 +7,93 @@ import ( "v2ray.com/core/common/serial" ) -type Reader struct { +type MetadataReader struct { + reader io.Reader + buffer []byte +} + +func NewMetadataReader(reader io.Reader) *MetadataReader { + return &MetadataReader{ + reader: reader, + buffer: make([]byte, 1024), + } +} + +func (r *MetadataReader) Read() (*FrameMetadata, error) { + metaLen, err := serial.ReadUint16(r.reader) + if err != nil { + return nil, err + } + if metaLen > 512 { + return nil, newError("invalid metalen ", metaLen).AtWarning() + } + + if _, err := io.ReadFull(r.reader, r.buffer[:metaLen]); err != nil { + return nil, err + } + return ReadFrameFrom(r.buffer) +} + +type PacketReader struct { + reader io.Reader + eof bool +} + +func NewPacketReader(reader io.Reader) *PacketReader { + return &PacketReader{ + reader: reader, + eof: false, + } +} + +func (r *PacketReader) Read() (buf.MultiBuffer, error) { + if r.eof { + return nil, io.EOF + } + + size, err := serial.ReadUint16(r.reader) + if err != nil { + return nil, err + } + + var b *buf.Buffer + if size <= buf.Size { + b = buf.New() + } else { + b = buf.NewLocal(int(size)) + } + if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, int(size))); err != nil { + b.Release() + return nil, err + } + r.eof = true + return buf.NewMultiBufferValue(b), nil +} + +type StreamReader struct { reader io.Reader - buffer *buf.Buffer leftOver int } -func NewReader(reader buf.Reader) *Reader { - return &Reader{ - reader: buf.ToBytesReader(reader), - buffer: buf.NewLocal(1024), +func NewStreamReader(reader io.Reader) *StreamReader { + return &StreamReader{ + reader: reader, leftOver: -1, } } -func (r *Reader) ReadMetadata() (*FrameMetadata, error) { - r.leftOver = -1 - - b := r.buffer - b.Clear() - - if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil { - return nil, err - } - metaLen := serial.BytesToUint16(b.Bytes()) - if metaLen > 512 { - return nil, newError("invalid metalen ", metaLen).AtWarning() - } - b.Clear() - if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, int(metaLen))); err != nil { - return nil, err - } - return ReadFrameFrom(b.Bytes()) -} - -func (r *Reader) readSize() error { - if err := r.buffer.Reset(buf.ReadFullFrom(r.reader, 2)); err != nil { - return err - } - r.leftOver = int(serial.BytesToUint16(r.buffer.Bytes())) - return nil -} - -func (r *Reader) Read() (buf.MultiBuffer, error) { +func (r *StreamReader) Read() (buf.MultiBuffer, error) { if r.leftOver == 0 { r.leftOver = -1 return nil, io.EOF } + if r.leftOver == -1 { - if err := r.readSize(); err != nil { + size, err := serial.ReadUint16(r.reader) + if err != nil { return nil, err } + r.leftOver = int(size) } mb := buf.NewMultiBuffer() @@ -79,6 +115,5 @@ func (r *Reader) Read() (buf.MultiBuffer, error) { break } } - return mb, nil } diff --git a/app/proxyman/mux/session.go b/app/proxyman/mux/session.go index 325ee63c..80bb1b3c 100644 --- a/app/proxyman/mux/session.go +++ b/app/proxyman/mux/session.go @@ -1,8 +1,11 @@ package mux import ( + "io" "sync" + "v2ray.com/core/common/buf" + "v2ray.com/core/common/protocol" "v2ray.com/core/transport/ray" ) @@ -119,6 +122,7 @@ type Session struct { ID uint16 uplinkClosed bool downlinkClosed bool + transferType protocol.TransferType } func (s *Session) CloseUplink() { @@ -142,3 +146,10 @@ func (s *Session) CloseDownlink() { s.parent.Remove(s.ID) } } + +func (s *Session) NewReader(reader io.Reader) buf.Reader { + if s.transferType == protocol.TransferTypeStream { + return NewStreamReader(reader) + } + return NewPacketReader(reader) +} diff --git a/app/proxyman/mux/writer.go b/app/proxyman/mux/writer.go index 57dbd4a3..47c41914 100644 --- a/app/proxyman/mux/writer.go +++ b/app/proxyman/mux/writer.go @@ -5,30 +5,34 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" ) type Writer struct { - id uint16 - dest net.Destination - writer buf.Writer - followup bool + id uint16 + dest net.Destination + writer buf.Writer + followup bool + transferType protocol.TransferType } -func NewWriter(id uint16, dest net.Destination, writer buf.Writer) *Writer { +func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType) *Writer { return &Writer{ - id: id, - dest: dest, - writer: writer, - followup: false, + id: id, + dest: dest, + writer: writer, + followup: false, + transferType: transferType, } } -func NewResponseWriter(id uint16, writer buf.Writer) *Writer { +func NewResponseWriter(id uint16, writer buf.Writer, transferType protocol.TransferType) *Writer { return &Writer{ - id: id, - writer: writer, - followup: true, + id: id, + writer: writer, + followup: true, + transferType: transferType, } } @@ -82,13 +86,22 @@ func (w *Writer) Write(mb buf.MultiBuffer) error { return w.writeMetaOnly() } - const chunkSize = 8 * 1024 - for !mb.IsEmpty() { - slice := mb.SliceBySize(chunkSize) - if err := w.writeData(slice); err != nil { - return err + if w.transferType == protocol.TransferTypeStream { + const chunkSize = 8 * 1024 + for !mb.IsEmpty() { + slice := mb.SliceBySize(chunkSize) + if err := w.writeData(slice); err != nil { + return err + } + } + } else { + for _, b := range mb { + if err := w.writeData(buf.NewMultiBufferValue(b)); err != nil { + return err + } } } + return nil } diff --git a/common/crypto/auth.go b/common/crypto/auth.go index 12ed4ed7..fd1b735e 100644 --- a/common/crypto/auth.go +++ b/common/crypto/auth.go @@ -5,6 +5,7 @@ import ( "io" "v2ray.com/core/common/buf" + "v2ray.com/core/common/protocol" ) type BytesGenerator interface { @@ -60,34 +61,27 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) { return v.AEAD.Seal(dst, iv, plainText, additionalData), nil } -type StreamMode int - -const ( - ModeStream StreamMode = iota - ModePacket -) - type AuthenticationReader struct { - auth Authenticator - buffer *buf.Buffer - reader io.Reader - sizeParser ChunkSizeDecoder - size int - mode StreamMode + auth Authenticator + buffer *buf.Buffer + reader io.Reader + sizeParser ChunkSizeDecoder + size int + transferType protocol.TransferType } const ( readerBufferSize = 32 * 1024 ) -func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, mode StreamMode) *AuthenticationReader { +func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType) *AuthenticationReader { return &AuthenticationReader{ - auth: auth, - buffer: buf.NewLocal(readerBufferSize), - reader: reader, - sizeParser: sizeParser, - size: -1, - mode: mode, + auth: auth, + buffer: buf.NewLocal(readerBufferSize), + reader: reader, + sizeParser: sizeParser, + size: -1, + transferType: transferType, } } @@ -153,7 +147,7 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { } mb := buf.NewMultiBuffer() - if r.mode == ModeStream { + if r.transferType == protocol.TransferTypeStream { mb.Write(b) } else { var bb *buf.Buffer @@ -171,7 +165,7 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { if err != nil { break } - if r.mode == ModeStream { + if r.transferType == protocol.TransferTypeStream { mb.Write(b) } else { var bb *buf.Buffer @@ -189,22 +183,22 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { } type AuthenticationWriter struct { - auth Authenticator - payload []byte - buffer *buf.Buffer - writer io.Writer - sizeParser ChunkSizeEncoder - mode StreamMode + auth Authenticator + payload []byte + buffer *buf.Buffer + writer io.Writer + sizeParser ChunkSizeEncoder + transferType protocol.TransferType } -func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, mode StreamMode) *AuthenticationWriter { +func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter { return &AuthenticationWriter{ - auth: auth, - payload: make([]byte, 1024), - buffer: buf.NewLocal(readerBufferSize), - writer: writer, - sizeParser: sizeParser, - mode: mode, + auth: auth, + payload: make([]byte, 1024), + buffer: buf.NewLocal(readerBufferSize), + writer: writer, + sizeParser: sizeParser, + transferType: transferType, } } @@ -279,7 +273,7 @@ func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error { } func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error { - if w.mode == ModeStream { + if w.transferType == protocol.TransferTypeStream { return w.writeStream(mb) } diff --git a/common/crypto/auth_test.go b/common/crypto/auth_test.go index 4cd196d0..a44ecfbf 100644 --- a/common/crypto/auth_test.go +++ b/common/crypto/auth_test.go @@ -9,6 +9,7 @@ import ( "v2ray.com/core/common/buf" . "v2ray.com/core/common/crypto" + "v2ray.com/core/common/protocol" "v2ray.com/core/testing/assert" ) @@ -39,7 +40,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { Content: iv, }, AdditionalDataGenerator: &NoOpBytesGenerator{}, - }, PlainChunkSizeParser{}, cache, ModeStream) + }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream) assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil() assert.Int(cache.Len()).Equals(83360) @@ -52,7 +53,7 @@ func TestAuthenticationReaderWriter(t *testing.T) { Content: iv, }, AdditionalDataGenerator: &NoOpBytesGenerator{}, - }, PlainChunkSizeParser{}, cache, ModeStream) + }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream) mb := buf.NewMultiBuffer() @@ -92,7 +93,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) { Content: iv, }, AdditionalDataGenerator: &NoOpBytesGenerator{}, - }, PlainChunkSizeParser{}, cache, ModePacket) + }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket) payload := buf.NewMultiBuffer() pb1 := buf.New() @@ -114,7 +115,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) { Content: iv, }, AdditionalDataGenerator: &NoOpBytesGenerator{}, - }, PlainChunkSizeParser{}, cache, ModePacket) + }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket) mb, err := reader.Read() assert.Error(err).IsNil() diff --git a/common/protocol/headers.go b/common/protocol/headers.go index 259af8f9..9ea26aa5 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -15,6 +15,14 @@ const ( RequestCommandUDP = RequestCommand(0x02) ) +func (c RequestCommand) TransferType() TransferType { + if c == RequestCommandTCP { + return TransferTypeStream + } + + return TransferTypePacket +} + // RequestOption is the options of a request. type RequestOption byte diff --git a/common/protocol/payload.go b/common/protocol/payload.go new file mode 100644 index 00000000..82214301 --- /dev/null +++ b/common/protocol/payload.go @@ -0,0 +1,8 @@ +package protocol + +type TransferType int + +const ( + TransferTypeStream TransferType = 0 + TransferTypePacket TransferType = 1 +) diff --git a/common/serial/numbers.go b/common/serial/numbers.go index 54ebda72..a81765e7 100644 --- a/common/serial/numbers.go +++ b/common/serial/numbers.go @@ -1,6 +1,7 @@ package serial import "strconv" +import "io" // Uint16ToBytes serializes an uint16 into bytes in big endian order. func Uint16ToBytes(value uint16, b []byte) []byte { @@ -11,6 +12,14 @@ func Uint16ToString(value uint16) string { return strconv.Itoa(int(value)) } +func ReadUint16(reader io.Reader) (uint16, error) { + var b [2]byte + if _, err := io.ReadFull(reader, b[:]); err != nil { + return 0, err + } + return BytesToUint16(b[:]), nil +} + func WriteUint16(value uint16) func([]byte) (int, error) { return func(b []byte) (int, error) { b = Uint16ToBytes(value, b[:0]) diff --git a/proxy/vmess/encoding/auth.go b/proxy/vmess/encoding/auth.go index bd1ee005..b96ed5b8 100644 --- a/proxy/vmess/encoding/auth.go +++ b/proxy/vmess/encoding/auth.go @@ -6,8 +6,6 @@ import ( "golang.org/x/crypto/sha3" - "v2ray.com/core/common/crypto" - "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" ) @@ -108,11 +106,3 @@ func (s *ShakeSizeParser) Encode(size uint16, b []byte) []byte { mask := s.next() return serial.Uint16ToBytes(mask^size, b[:0]) } - -func GetStreamMode(request *protocol.RequestHeader) crypto.StreamMode { - if request.Command == protocol.RequestCommandTCP { - return crypto.ModeStream - } - - return crypto.ModePacket -} diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index baa32c16..28437d8d 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -131,7 +131,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, crypto.ModePacket) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket) } return buf.NewWriter(writer) @@ -146,7 +146,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter, GetStreamMode(request)) + return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter, request.Command.TransferType()) } return buf.NewWriter(cryptionWriter) @@ -164,7 +164,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request)) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType()) } if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { @@ -178,7 +178,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request)) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType()) } panic("Unknown security type.") @@ -239,7 +239,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, crypto.ModePacket) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket) } return buf.NewReader(reader) @@ -252,7 +252,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationReader(auth, sizeParser, v.responseReader, GetStreamMode(request)) + return crypto.NewAuthenticationReader(auth, sizeParser, v.responseReader, request.Command.TransferType()) } return buf.NewReader(v.responseReader) @@ -270,7 +270,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request)) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType()) } if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { @@ -284,7 +284,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request)) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType()) } panic("Unknown security type.") diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index a11249b5..dde8f76f 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -249,7 +249,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, crypto.ModePacket) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket) } return buf.NewReader(reader) @@ -264,7 +264,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, GetStreamMode(request)) + return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, request.Command.TransferType()) } return buf.NewReader(cryptionReader) @@ -282,7 +282,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request)) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType()) } if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { @@ -296,7 +296,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request)) + return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType()) } panic("Unknown security type.") @@ -335,7 +335,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ NonceGenerator: &crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, crypto.ModePacket) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket) } return buf.NewWriter(writer) @@ -348,7 +348,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ NonceGenerator: crypto.NoOpBytesGenerator{}, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationWriter(auth, sizeParser, v.responseWriter, GetStreamMode(request)) + return crypto.NewAuthenticationWriter(auth, sizeParser, v.responseWriter, request.Command.TransferType()) } return buf.NewWriter(v.responseWriter) @@ -366,7 +366,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request)) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType()) } if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { @@ -380,7 +380,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, } - return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request)) + return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType()) } panic("Unknown security type.")