diff --git a/common/buf/reader.go b/common/buf/reader.go index 59d48709..fbe26310 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -38,12 +38,12 @@ func NewBytesToBufferReader(reader io.Reader) Reader { func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) { b, err := readOne(r.Reader) - if b.IsFull() && largeSize > Size { - r.buffer = newBytes(Size + 1) - } if err != nil { return nil, err } + if b.IsFull() && largeSize > Size { + r.buffer = newBytes(Size + 1) + } return NewMultiBufferValue(b), nil } diff --git a/common/buf/readv_reader.go b/common/buf/readv_reader.go index 5d8c2155..327cf97d 100644 --- a/common/buf/readv_reader.go +++ b/common/buf/readv_reader.go @@ -34,7 +34,7 @@ func allocN(n int32) []*Buffer { return bs } -func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) { +func (r *ReadVReader) readMulti() (MultiBuffer, error) { bs := allocN(r.nBuf) var iovecs []syscall.Iovec @@ -72,8 +72,6 @@ func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) { return nil, io.EOF } - var isFull bool = (nBytes == int(r.nBuf)*Size) - nBuf := 0 for nBuf < len(bs) { if nBytes <= 0 { @@ -93,13 +91,33 @@ func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) { bs[i] = nil } - if isFull && nBuf < 8 { - r.nBuf *= 4 - } else { - r.nBuf = int32(nBuf) + return MultiBuffer(bs[:nBuf]), nil +} + +// ReadMultiBuffer implements Reader. +func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) { + if r.nBuf == 1 { + b, err := readOne(r.Reader) + if err != nil { + return nil, err + } + if b.IsFull() { + r.nBuf = 2 + } + return NewMultiBufferValue(b), nil } - return MultiBuffer(bs[:nBuf]), nil + mb, err := r.readMulti() + if err != nil { + return nil, err + } + nBuf := int32(len(mb)) + if nBuf < r.nBuf { + r.nBuf = nBuf + } else if nBuf == r.nBuf && r.nBuf < 16 { + r.nBuf *= 4 + } + return mb, nil } var useReadv = false diff --git a/testing/scenarios/vmess_test.go b/testing/scenarios/vmess_test.go index 2ef36227..21a5626d 100644 --- a/testing/scenarios/vmess_test.go +++ b/testing/scenarios/vmess_test.go @@ -2,6 +2,8 @@ package scenarios import ( "crypto/rand" + "os" + "runtime" "sync" "testing" "time" @@ -9,6 +11,7 @@ import ( "v2ray.com/core" "v2ray.com/core/app/log" "v2ray.com/core/app/proxyman" + "v2ray.com/core/common" "v2ray.com/core/common/compare" clog "v2ray.com/core/common/log" "v2ray.com/core/common/net" @@ -261,13 +264,152 @@ func TestVMessGCM(t *testing.T) { */ servers, err := InitializeServerConfigs(serverConfig, clientConfig) - assert(err, IsNil) + if err != nil { + t.Fatal("Failed to initialize all servers: ", err.Error()) + } defer CloseAllServers(servers) var wg sync.WaitGroup wg.Add(10) for i := 0; i < 10; i++ { go func() { + defer wg.Done() + + conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: int(clientPort), + }) + assert(err, IsNil) + defer conn.Close() // nolint: errcheck + + payload := make([]byte, 10240*1024) + rand.Read(payload) + + nBytes, err := conn.Write([]byte(payload)) + assert(err, IsNil) + assert(nBytes, Equals, len(payload)) + + response := readFrom(conn, time.Second*40, 10240*1024) + if err := compare.BytesEqualWithDetail(response, xor([]byte(payload))); err != nil { + t.Error(err) + } + }() + } + wg.Wait() +} + +func TestVMessGCMReadv(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Not supported on Windows yet.") + return + } + assert := With(t) + + tcpServer := tcp.Server{ + MsgProcessor: xor, + } + dest, err := tcpServer.Start() + assert(err, IsNil) + defer tcpServer.Close() + + userID := protocol.NewID(uuid.New()) + serverPort := tcp.PickPort() + serverConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(serverPort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&inbound.Config{ + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + AlterId: 64, + }), + }, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + clientPort := tcp.PickPort() + clientConfig := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&log.Config{ + ErrorLogLevel: clog.Severity_Debug, + ErrorLogType: log.LogType_Console, + }), + }, + Inbound: []*core.InboundHandlerConfig{ + { + ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{ + PortRange: net.SinglePortRange(clientPort), + Listen: net.NewIPOrDomain(net.LocalHostIP), + }), + ProxySettings: serial.ToTypedMessage(&dokodemo.Config{ + Address: net.NewIPOrDomain(dest.Address), + Port: uint32(dest.Port), + NetworkList: &net.NetworkList{ + Network: []net.Network{net.Network_TCP}, + }, + }), + }, + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&outbound.Config{ + Receiver: []*protocol.ServerEndpoint{ + { + Address: net.NewIPOrDomain(net.LocalHostIP), + Port: uint32(serverPort), + User: []*protocol.User{ + { + Account: serial.ToTypedMessage(&vmess.Account{ + Id: userID.String(), + AlterId: 64, + SecuritySettings: &protocol.SecurityConfig{ + Type: protocol.SecurityType_AES128_GCM, + }, + }), + }, + }, + }, + }, + }), + }, + }, + } + + const envName = "V2RAY_BUF_READV" + common.Must(os.Setenv(envName, "1")) + defer os.Unsetenv(envName) + + servers, err := InitializeServerConfigs(serverConfig, clientConfig) + if err != nil { + t.Fatal("Failed to initialize all servers: ", err.Error()) + } + defer CloseAllServers(servers) + + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ IP: []byte{127, 0, 0, 1}, Port: int(clientPort), @@ -286,7 +428,6 @@ func TestVMessGCM(t *testing.T) { if err := compare.BytesEqualWithDetail(response, xor([]byte(payload))); err != nil { t.Error(err) } - wg.Done() }() } wg.Wait()