implement WriteMultiBuffer

pull/1549/head
Darien Raymond 2017-12-03 22:53:00 +01:00
parent be714f76f1
commit b3e6994e52
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
6 changed files with 86 additions and 50 deletions

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"v2ray.com/core/app/log" "v2ray.com/core/app/log"
"v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/predicate" "v2ray.com/core/common/predicate"
) )
@ -343,10 +344,15 @@ func (v *Connection) Write(b []byte) (int, error) {
return totalWritten, io.ErrClosedPipe return totalWritten, io.ErrClosedPipe
} }
nBytes := v.sendingWorker.Push(b[totalWritten:]) for {
v.dataUpdater.WakeUp() rb := v.sendingWorker.Push()
if nBytes > 0 { if rb == nil {
totalWritten += nBytes break
}
common.Must(rb.Reset(func(bb []byte) (int, error) {
return copy(bb[:v.mss], b[totalWritten:]), nil
}))
totalWritten += rb.Len()
if totalWritten == len(b) { if totalWritten == len(b) {
return totalWritten, nil return totalWritten, nil
} }
@ -370,6 +376,45 @@ func (v *Connection) Write(b []byte) (int, error) {
} }
} }
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release()
for {
if v == nil || v.State() != StateActive {
return io.ErrClosedPipe
}
for {
rb := v.sendingWorker.Push()
if rb == nil {
break
}
common.Must(rb.Reset(func(bb []byte) (int, error) {
return mb.Read(bb[:v.mss])
}))
if mb.IsEmpty() {
return nil
}
}
duration := time.Minute
if !v.wd.IsZero() {
duration = time.Until(v.wd)
if duration < 0 {
return ErrIOTimeout
}
}
select {
case <-v.dataOutput:
case <-time.After(duration):
if !v.wd.IsZero() && v.wd.Before(time.Now()) {
return ErrIOTimeout
}
}
}
}
func (v *Connection) SetState(state State) { func (v *Connection) SetState(state State) {
current := v.Elapsed() current := v.Elapsed()
atomic.StoreInt32((*int32)(&v.state), int32(state)) atomic.StoreInt32((*int32)(&v.state), int32(state))

View File

@ -214,8 +214,7 @@ func (w *ReceivingWorker) ReadMultiBuffer() buf.MultiBuffer {
} }
w.window.Advance() w.window.Advance()
w.nextNumber++ w.nextNumber++
mb.Append(seg.Data) mb.Append(seg.Detach())
seg.Data = nil
seg.Release() seg.Release()
} }

View File

@ -1,7 +1,6 @@
package kcp package kcp
import ( import (
"v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
) )
@ -44,8 +43,8 @@ type DataSegment struct {
Timestamp uint32 Timestamp uint32
Number uint32 Number uint32
SendingNext uint32 SendingNext uint32
Data *buf.Buffer
payload *buf.Buffer
timeout uint32 timeout uint32
transmit uint32 transmit uint32
} }
@ -62,13 +61,17 @@ func (v *DataSegment) Command() Command {
return CommandData return CommandData
} }
func (v *DataSegment) SetData(data []byte) { func (v *DataSegment) Detach() *buf.Buffer {
if v.Data == nil { r := v.payload
v.Data = buf.New() v.payload = nil
return r
} }
common.Must(v.Data.Reset(func(b []byte) (int, error) {
return copy(b, data), nil func (v *DataSegment) Data() *buf.Buffer {
})) if v.payload == nil {
v.payload = buf.New()
}
return v.payload
} }
func (v *DataSegment) Bytes() buf.Supplier { func (v *DataSegment) Bytes() buf.Supplier {
@ -78,19 +81,19 @@ func (v *DataSegment) Bytes() buf.Supplier {
b = serial.Uint32ToBytes(v.Timestamp, b) b = serial.Uint32ToBytes(v.Timestamp, b)
b = serial.Uint32ToBytes(v.Number, b) b = serial.Uint32ToBytes(v.Number, b)
b = serial.Uint32ToBytes(v.SendingNext, b) b = serial.Uint32ToBytes(v.SendingNext, b)
b = serial.Uint16ToBytes(uint16(v.Data.Len()), b) b = serial.Uint16ToBytes(uint16(v.payload.Len()), b)
b = append(b, v.Data.Bytes()...) b = append(b, v.payload.Bytes()...)
return len(b), nil return len(b), nil
} }
} }
func (v *DataSegment) ByteSize() int { func (v *DataSegment) ByteSize() int {
return 2 + 1 + 1 + 4 + 4 + 4 + 2 + v.Data.Len() return 2 + 1 + 1 + 4 + 4 + 4 + 2 + v.payload.Len()
} }
func (v *DataSegment) Release() { func (v *DataSegment) Release() {
v.Data.Release() v.payload.Release()
v.Data = nil v.payload = nil
} }
type AckSegment struct { type AckSegment struct {
@ -233,7 +236,8 @@ func ReadSegment(buf []byte) (Segment, []byte) {
if len(buf) < dataLen { if len(buf) < dataLen {
return nil, nil return nil, nil
} }
seg.SetData(buf[:dataLen]) seg.Data().Clear()
seg.Data().Append(buf[:dataLen])
buf = buf[dataLen:] buf = buf[dataLen:]
return seg, buf return seg, buf

View File

@ -3,7 +3,6 @@ package kcp_test
import ( import (
"testing" "testing"
"v2ray.com/core/common/buf"
. "v2ray.com/core/transport/internet/kcp" . "v2ray.com/core/transport/internet/kcp"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
) )
@ -19,15 +18,13 @@ func TestBadSegment(t *testing.T) {
func TestDataSegment(t *testing.T) { func TestDataSegment(t *testing.T) {
assert := With(t) assert := With(t)
b := buf.NewLocal(512)
b.Append([]byte{'a', 'b', 'c', 'd'})
seg := &DataSegment{ seg := &DataSegment{
Conv: 1, Conv: 1,
Timestamp: 3, Timestamp: 3,
Number: 4, Number: 4,
SendingNext: 5, SendingNext: 5,
Data: b,
} }
seg.Data().Append([]byte{'a', 'b', 'c', 'd'})
nBytes := seg.ByteSize() nBytes := seg.ByteSize()
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
@ -41,21 +38,19 @@ func TestDataSegment(t *testing.T) {
assert(seg2.Timestamp, Equals, seg.Timestamp) assert(seg2.Timestamp, Equals, seg.Timestamp)
assert(seg2.SendingNext, Equals, seg.SendingNext) assert(seg2.SendingNext, Equals, seg.SendingNext)
assert(seg2.Number, Equals, seg.Number) assert(seg2.Number, Equals, seg.Number)
assert(seg2.Data.Bytes(), Equals, seg.Data.Bytes()) assert(seg2.Data().Bytes(), Equals, seg.Data().Bytes())
} }
func Test1ByteDataSegment(t *testing.T) { func Test1ByteDataSegment(t *testing.T) {
assert := With(t) assert := With(t)
b := buf.NewLocal(512)
b.AppendBytes('a')
seg := &DataSegment{ seg := &DataSegment{
Conv: 1, Conv: 1,
Timestamp: 3, Timestamp: 3,
Number: 4, Number: 4,
SendingNext: 5, SendingNext: 5,
Data: b,
} }
seg.Data().AppendBytes('a')
nBytes := seg.ByteSize() nBytes := seg.ByteSize()
bytes := make([]byte, nBytes) bytes := make([]byte, nBytes)
@ -69,7 +64,7 @@ func Test1ByteDataSegment(t *testing.T) {
assert(seg2.Timestamp, Equals, seg.Timestamp) assert(seg2.Timestamp, Equals, seg.Timestamp)
assert(seg2.SendingNext, Equals, seg.SendingNext) assert(seg2.SendingNext, Equals, seg.SendingNext)
assert(seg2.Number, Equals, seg.Number) assert(seg2.Number, Equals, seg.Number)
assert(seg2.Data.Bytes(), Equals, seg.Data.Bytes()) assert(seg2.Data().Bytes(), Equals, seg.Data().Bytes())
} }
func TestACKSegment(t *testing.T) { func TestACKSegment(t *testing.T) {

View File

@ -2,6 +2,8 @@ package kcp
import ( import (
"sync" "sync"
"v2ray.com/core/common/buf"
) )
type SendingWindow struct { type SendingWindow struct {
@ -62,9 +64,8 @@ func (sw *SendingWindow) IsFull() bool {
return sw.len == sw.cap return sw.len == sw.cap
} }
func (sw *SendingWindow) Push(number uint32, data []byte) { func (sw *SendingWindow) Push(number uint32) *buf.Buffer {
pos := (sw.start + sw.len) % sw.cap pos := (sw.start + sw.len) % sw.cap
sw.data[pos].SetData(data)
sw.data[pos].Number = number sw.data[pos].Number = number
sw.data[pos].timeout = 0 sw.data[pos].timeout = 0
sw.data[pos].transmit = 0 sw.data[pos].transmit = 0
@ -75,6 +76,7 @@ func (sw *SendingWindow) Push(number uint32, data []byte) {
} }
sw.last = pos sw.last = pos
sw.len++ sw.len++
return sw.data[pos].Data()
} }
func (sw *SendingWindow) FirstNumber() uint32 { func (sw *SendingWindow) FirstNumber() uint32 {
@ -224,7 +226,6 @@ func (v *SendingWorker) ProcessReceivingNextWithoutLock(nextNumber uint32) {
v.FindFirstUnacknowledged() v.FindFirstUnacknowledged()
} }
// Private: Visible for testing.
func (v *SendingWorker) FindFirstUnacknowledged() { func (v *SendingWorker) FindFirstUnacknowledged() {
first := v.firstUnacknowledged first := v.firstUnacknowledged
if !v.window.IsEmpty() { if !v.window.IsEmpty() {
@ -283,24 +284,16 @@ func (v *SendingWorker) ProcessSegment(current uint32, seg *AckSegment, rto uint
} }
} }
func (v *SendingWorker) Push(b []byte) int { func (v *SendingWorker) Push() *buf.Buffer {
nBytes := 0
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
for len(b) > 0 && !v.window.IsFull() { if !v.window.IsFull() {
var size int b := v.window.Push(v.nextNumber)
if len(b) > int(v.conn.mss) {
size = int(v.conn.mss)
} else {
size = len(b)
}
v.window.Push(v.nextNumber, b[:size])
v.nextNumber++ v.nextNumber++
b = b[size:] return b
nBytes += size
} }
return nBytes return nil
} }
// Private: Visible for testing. // Private: Visible for testing.

View File

@ -11,9 +11,9 @@ func TestSendingWindow(t *testing.T) {
assert := With(t) assert := With(t)
window := NewSendingWindow(5, nil, nil) window := NewSendingWindow(5, nil, nil)
window.Push(0, []byte{}) window.Push(0)
window.Push(1, []byte{}) window.Push(1)
window.Push(2, []byte{}) window.Push(2)
assert(window.Len(), Equals, 3) assert(window.Len(), Equals, 3)
window.Remove(1) window.Remove(1)
@ -27,11 +27,11 @@ func TestSendingWindow(t *testing.T) {
window.Remove(0) window.Remove(0)
assert(window.Len(), Equals, 0) assert(window.Len(), Equals, 0)
window.Push(4, []byte{}) window.Push(4)
assert(window.Len(), Equals, 1) assert(window.Len(), Equals, 1)
assert(window.FirstNumber(), Equals, uint32(4)) assert(window.FirstNumber(), Equals, uint32(4))
window.Push(5, []byte{}) window.Push(5)
assert(window.Len(), Equals, 2) assert(window.Len(), Equals, 2)
window.Remove(1) window.Remove(1)