From 5bbece14afdc8db25eb61584b931c35c0389eceb Mon Sep 17 00:00:00 2001
From: Darien Raymond <admin@v2ray.com>
Date: Mon, 12 Mar 2018 16:21:39 +0100
Subject: [PATCH] simplify pool creation

---
 common/buf/buffer.go      | 41 +++---------------------------------
 common/buf/buffer_pool.go | 44 ++++++++++++++++++++++++++++++---------
 common/buf/reader.go      | 10 ++++-----
 common/buf/reader_test.go | 10 ++++++++-
 4 files changed, 50 insertions(+), 55 deletions(-)

diff --git a/common/buf/buffer.go b/common/buf/buffer.go
index f89b334e..356a9c85 100644
--- a/common/buf/buffer.go
+++ b/common/buf/buffer.go
@@ -23,7 +23,7 @@ func (b *Buffer) Release() {
 	if b == nil || b.v == nil {
 		return
 	}
-	FreeBytes(b.v)
+	freeBytes(b.v)
 	b.v = nil
 	b.start = 0
 	b.end = 0
@@ -175,48 +175,13 @@ func (b *Buffer) String() string {
 // New creates a Buffer with 0 length and 2K capacity.
 func New() *Buffer {
 	return &Buffer{
-		v: pool2k.Get().([]byte),
+		v: pool[0].Get().([]byte),
 	}
 }
 
 // NewSize creates and returns a buffer given capacity.
 func NewSize(size uint32) *Buffer {
 	return &Buffer{
-		v: NewBytes(size),
-	}
-}
-
-func NewBytes(size uint32) []byte {
-	if size > 128*1024 {
-		return make([]byte, size)
-	}
-
-	if size > 64*1024 {
-		return pool128k.Get().([]byte)
-	}
-
-	if size > 8*1024 {
-		return pool64k.Get().([]byte)
-	}
-
-	if size > 2*1024 {
-		return pool8k.Get().([]byte)
-	}
-
-	return pool2k.Get().([]byte)
-}
-
-func FreeBytes(b []byte) {
-	size := cap(b)
-	b = b[0:cap(b)]
-	switch {
-	case size >= 128*1024:
-		pool128k.Put(b)
-	case size >= 64*1024:
-		pool64k.Put(b)
-	case size >= 8*1024:
-		pool8k.Put(b)
-	case size >= 2*1024:
-		pool2k.Put(b)
+		v: newBytes(size),
 	}
 }
diff --git a/common/buf/buffer_pool.go b/common/buf/buffer_pool.go
index 92dcd34c..1c7a9ffe 100644
--- a/common/buf/buffer_pool.go
+++ b/common/buf/buffer_pool.go
@@ -15,18 +15,42 @@ func createAllocFunc(size uint32) func() interface{} {
 	}
 }
 
-var pool2k = &sync.Pool{
-	New: createAllocFunc(2 * 1024),
+const (
+	numPools  = 5
+	sizeMulti = 4
+)
+
+var (
+	pool     [numPools]*sync.Pool
+	poolSize [numPools]uint32
+)
+
+func init() {
+	size := uint32(Size)
+	for i := 0; i < numPools; i++ {
+		pool[i] = &sync.Pool{
+			New: createAllocFunc(size),
+		}
+		poolSize[i] = size
+		size *= sizeMulti
+	}
 }
 
-var pool8k = &sync.Pool{
-	New: createAllocFunc(8 * 1024),
+func newBytes(size uint32) []byte {
+	for idx, ps := range poolSize {
+		if size <= ps {
+			return pool[idx].Get().([]byte)
+		}
+	}
+	return make([]byte, size)
 }
 
-var pool64k = &sync.Pool{
-	New: createAllocFunc(64 * 1024),
-}
-
-var pool128k = &sync.Pool{
-	New: createAllocFunc(128 * 1024),
+func freeBytes(b []byte) {
+	size := uint32(cap(b))
+	for i := numPools - 1; i >= 0; i-- {
+		ps := poolSize[i]
+		if size >= ps {
+			pool[i].Put(b)
+		}
+	}
 }
diff --git a/common/buf/reader.go b/common/buf/reader.go
index f040a937..87f36808 100644
--- a/common/buf/reader.go
+++ b/common/buf/reader.go
@@ -19,15 +19,13 @@ func NewBytesToBufferReader(reader io.Reader) Reader {
 	}
 }
 
-const mediumSize = 8 * 1024
-const largeSize = 64 * 1024
 const xlSize = 128 * 1024
 
 func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) {
 	b := New()
 	err := b.Reset(ReadFrom(r.Reader))
 	if b.IsFull() {
-		r.buffer = NewBytes(mediumSize)
+		r.buffer = newBytes(Size + 1)
 	}
 	if !b.IsEmpty() {
 		return NewMultiBufferValue(b), nil
@@ -47,12 +45,12 @@ func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) {
 		mb := NewMultiBufferCap(nBytes/Size + 1)
 		mb.Write(r.buffer[:nBytes])
 		if nBytes == len(r.buffer) && nBytes < xlSize {
-			FreeBytes(r.buffer)
-			r.buffer = NewBytes(uint32(nBytes) + 1)
+			freeBytes(r.buffer)
+			r.buffer = newBytes(uint32(nBytes) + 1)
 		}
 		return mb, nil
 	}
-	FreeBytes(r.buffer)
+	freeBytes(r.buffer)
 	r.buffer = nil
 	return nil, err
 }
diff --git a/common/buf/reader_test.go b/common/buf/reader_test.go
index 9c90cfe0..f5964ff9 100644
--- a/common/buf/reader_test.go
+++ b/common/buf/reader_test.go
@@ -25,7 +25,15 @@ func TestAdaptiveReader(t *testing.T) {
 
 	b, err = reader.ReadMultiBuffer()
 	assert(err, IsNil)
-	assert(b.Len(), Equals, 64*1024)
+	assert(b.Len(), Equals, 32*1024)
+
+	b, err = reader.ReadMultiBuffer()
+	assert(err, IsNil)
+	assert(b.Len(), Equals, 128*1024)
+
+	b, err = reader.ReadMultiBuffer()
+	assert(err, IsNil)
+	assert(b.Len(), Equals, 128*1024)
 }
 
 func TestBytesReaderWriteTo(t *testing.T) {