diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index beb2c06e..be9b8657 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -77,12 +77,12 @@ func NewConnection(conv uint32, writerCloser io.WriteCloser, local *net.UDPAddr, conn.block = block conn.writer = writerCloser conn.since = nowMillisec() + conn.writeBufferSize = effectiveConfig.WriteBuffer / effectiveConfig.Mtu mtu := effectiveConfig.Mtu - uint32(block.HeaderSize()) - headerSize - conn.kcp = NewKCP(conv, mtu, effectiveConfig.GetSendingWindowSize(), effectiveConfig.GetReceivingWindowSize(), conn.output) + conn.kcp = NewKCP(conv, mtu, effectiveConfig.GetSendingWindowSize(), effectiveConfig.GetReceivingWindowSize(), conn.writeBufferSize, conn.output) conn.kcp.NoDelay(effectiveConfig.Tti, 2, effectiveConfig.Congestion) conn.kcp.current = conn.Elapsed() - conn.writeBufferSize = effectiveConfig.WriteBuffer / effectiveConfig.Mtu go conn.updateTask() @@ -133,6 +133,7 @@ func (this *Connection) Write(b []byte) (int, error) { this.state == ConnStateClosed { return 0, io.ErrClosedPipe } + totalWritten := 0 for { this.RLock() @@ -140,23 +141,26 @@ func (this *Connection) Write(b []byte) (int, error) { this.state == ConnStatePeerClosed || this.state == ConnStateClosed { this.RUnlock() - return 0, io.ErrClosedPipe + return totalWritten, io.ErrClosedPipe } this.RUnlock() this.kcpAccess.Lock() - if this.kcp.WaitSnd() < this.writeBufferSize { - nBytes := len(b) - this.kcp.Send(b) + nBytes := this.kcp.Send(b[totalWritten:]) + if nBytes > 0 { this.kcp.current = this.Elapsed() this.kcp.flush() - this.kcpAccess.Unlock() - return nBytes, nil + totalWritten += nBytes + if totalWritten == len(b) { + this.kcpAccess.Unlock() + return totalWritten, nil + } } + this.kcpAccess.Unlock() if !this.wd.IsZero() && this.wd.Before(time.Now()) { - return 0, errTimeout + return totalWritten, errTimeout } // Sending windows is 1024 for the moment. This amount is not gonna sent in 1 sec. diff --git a/transport/internet/kcp/kcp.go b/transport/internet/kcp/kcp.go index 6be774fc..ab4c0886 100644 --- a/transport/internet/kcp/kcp.go +++ b/transport/internet/kcp/kcp.go @@ -146,7 +146,7 @@ type KCP struct { ts_probe, probe_wait uint32 dead_link, incr uint32 - snd_queue []*Segment + snd_queue *SendingQueue rcv_queue []*Segment snd_buf []*Segment rcv_buf *ReceivingWindow @@ -161,7 +161,7 @@ type KCP struct { // NewKCP create a new kcp control object, 'conv' must equal in two endpoint // from the same connection. -func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSize uint32, output Output) *KCP { +func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSize uint32, sendingQueueSize uint32, output Output) *KCP { kcp := new(KCP) kcp.conv = conv kcp.snd_wnd = sendingWindowSize @@ -177,6 +177,7 @@ func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSi kcp.dead_link = IKCP_DEADLINK kcp.output = output kcp.rcv_buf = NewReceivingWindow(receivingWindowSize) + kcp.snd_queue = NewSendingQueue(sendingQueueSize) return kcp } @@ -232,26 +233,8 @@ func (kcp *KCP) DumpReceivingBuf() { // Send is user/upper level send, returns below zero for error func (kcp *KCP) Send(buffer []byte) int { - var count int - if len(buffer) == 0 { - return -1 - } - - if len(buffer) < int(kcp.mss) { - count = 1 - } else { - count = (len(buffer) + int(kcp.mss) - 1) / int(kcp.mss) - } - - if count > 255 { - return -2 - } - - if count == 0 { - count = 1 - } - - for i := 0; i < count; i++ { + nBytes := 0 + for len(buffer) > 0 && !kcp.snd_queue.IsFull() { var size int if len(buffer) > int(kcp.mss) { size = int(kcp.mss) @@ -260,11 +243,11 @@ func (kcp *KCP) Send(buffer []byte) int { } seg := NewSegment() seg.data.Append(buffer[:size]) - seg.frg = uint32(count - i - 1) - kcp.snd_queue = append(kcp.snd_queue, seg) + kcp.snd_queue.Push(seg) buffer = buffer[size:] + nBytes += size } - return 0 + return nBytes } // https://tools.ietf.org/html/rfc6298 @@ -572,12 +555,8 @@ func (kcp *KCP) flush() { cwnd = _imin_(kcp.cwnd, cwnd) } - count = 0 - for k := range kcp.snd_queue { - if _itimediff(kcp.snd_nxt, cwnd) >= 0 { - break - } - newseg := kcp.snd_queue[k] + for !kcp.snd_queue.IsEmpty() && _itimediff(kcp.snd_nxt, cwnd) < 0 { + newseg := kcp.snd_queue.Pop() newseg.conv = kcp.conv newseg.cmd = IKCP_CMD_PUSH newseg.wnd = seg.wnd @@ -589,9 +568,7 @@ func (kcp *KCP) flush() { newseg.xmit = 0 kcp.snd_buf = append(kcp.snd_buf, newseg) kcp.snd_nxt++ - count++ } - kcp.snd_queue = kcp.snd_queue[count:] // calculate resent resent := uint32(kcp.fastresend) @@ -774,14 +751,11 @@ func (kcp *KCP) NoDelay(interval uint32, resend int, congestionControl bool) int // WaitSnd gets how many packet is waiting to be sent func (kcp *KCP) WaitSnd() uint32 { - return uint32(len(kcp.snd_buf) + len(kcp.snd_queue)) + return uint32(len(kcp.snd_buf)) + kcp.snd_queue.Len() } func (this *KCP) ClearSendQueue() { - for _, seg := range this.snd_queue { - seg.Release() - } - this.snd_queue = nil + this.snd_queue.Clear() for _, seg := range this.snd_buf { seg.Release() diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go new file mode 100644 index 00000000..e63c30e9 --- /dev/null +++ b/transport/internet/kcp/sending.go @@ -0,0 +1,60 @@ +package kcp + +type SendingQueue struct { + start uint32 + cap uint32 + len uint32 + list []*Segment +} + +func NewSendingQueue(size uint32) *SendingQueue { + return &SendingQueue{ + start: 0, + cap: size, + list: make([]*Segment, size), + len: 0, + } +} + +func (this *SendingQueue) IsFull() bool { + return this.len == this.cap +} + +func (this *SendingQueue) IsEmpty() bool { + return this.len == 0 +} + +func (this *SendingQueue) Pop() *Segment { + if this.IsEmpty() { + return nil + } + seg := this.list[this.start] + this.list[this.start] = nil + this.len-- + this.start++ + if this.start == this.cap { + this.start = 0 + } + return seg +} + +func (this *SendingQueue) Push(seg *Segment) { + if this.IsFull() { + return + } + this.list[(this.start+this.len)%this.cap] = seg + this.len++ +} + +func (this *SendingQueue) Clear() { + for i := uint32(0); i < this.len; i++ { + this.list[(i+this.start)%this.cap].Release() + this.list[(i+this.start)%this.cap] = nil + } + this.start = 0 + this.len = 0 +} + +func (this *SendingQueue) Len() uint32 { + return this.len +} diff --git a/transport/internet/kcp/sending_test.go b/transport/internet/kcp/sending_test.go new file mode 100644 index 00000000..35224691 --- /dev/null +++ b/transport/internet/kcp/sending_test.go @@ -0,0 +1,64 @@ +package kcp_test + +import ( + "testing" + + "github.com/v2ray/v2ray-core/testing/assert" + . "github.com/v2ray/v2ray-core/transport/internet/kcp" +) + +func TestSendingQueue(t *testing.T) { + assert := assert.On(t) + + queue := NewSendingQueue(3) + + seg0 := &Segment{} + seg1 := &Segment{} + seg2 := &Segment{} + seg3 := &Segment{} + + assert.Bool(queue.IsEmpty()).IsTrue() + assert.Bool(queue.IsFull()).IsFalse() + + queue.Push(seg0) + assert.Bool(queue.IsEmpty()).IsFalse() + + queue.Push(seg1) + queue.Push(seg2) + + assert.Bool(queue.IsFull()).IsTrue() + + assert.Pointer(queue.Pop()).Equals(seg0) + + queue.Push(seg3) + assert.Bool(queue.IsFull()).IsTrue() + + assert.Pointer(queue.Pop()).Equals(seg1) + assert.Pointer(queue.Pop()).Equals(seg2) + assert.Pointer(queue.Pop()).Equals(seg3) + assert.Int(int(queue.Len())).Equals(0) +} + +func TestSendingQueueClear(t *testing.T) { + assert := assert.On(t) + + queue := NewSendingQueue(3) + + seg0 := &Segment{} + seg1 := &Segment{} + seg2 := &Segment{} + seg3 := &Segment{} + + queue.Push(seg0) + assert.Bool(queue.IsEmpty()).IsFalse() + + queue.Clear() + assert.Bool(queue.IsEmpty()).IsTrue() + + queue.Push(seg1) + queue.Push(seg2) + queue.Push(seg3) + + queue.Clear() + assert.Bool(queue.IsEmpty()).IsTrue() +}