refine connection.read

pull/215/head
v2ray 2016-07-06 17:34:38 +02:00
parent a615afc906
commit 56ce062154
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
3 changed files with 102 additions and 136 deletions

View File

@ -99,9 +99,11 @@ func (this *RountTripInfo) SmoothedTime() uint32 {
type Connection struct { type Connection struct {
block Authenticator block Authenticator
local, remote net.Addr local, remote net.Addr
rd time.Time
wd time.Time // write deadline wd time.Time // write deadline
writer io.WriteCloser writer io.WriteCloser
since int64 since int64
dataInputCond *sync.Cond
conv uint16 conv uint16
state State state State
@ -133,6 +135,7 @@ func NewConnection(conv uint16, writerCloser io.WriteCloser, local *net.UDPAddr,
conn.block = block conn.block = block
conn.writer = writerCloser conn.writer = writerCloser
conn.since = nowMillisec() conn.since = nowMillisec()
conn.dataInputCond = sync.NewCond(new(sync.Mutex))
authWriter := &AuthenticationWriter{ authWriter := &AuthenticationWriter{
Authenticator: block, Authenticator: block,
@ -167,10 +170,28 @@ func (this *Connection) Read(b []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
} }
if this.State() == StateTerminating || this.State() == StateTerminated { for {
return 0, io.EOF if this.State() == StateReadyToClose || this.State() == StateTerminating || this.State() == StateTerminated {
return 0, io.EOF
}
nBytes := this.receivingWorker.Read(b)
if nBytes > 0 {
return nBytes, nil
}
var timer *time.Timer
if !this.rd.IsZero() && this.rd.Before(time.Now()) {
timer = time.AfterFunc(this.rd.Sub(time.Now()), this.dataInputCond.Signal)
}
this.dataInputCond.L.Lock()
this.dataInputCond.Wait()
this.dataInputCond.L.Unlock()
if timer != nil {
timer.Stop()
}
if !this.rd.IsZero() && this.rd.Before(time.Now()) {
return 0, errTimeout
}
} }
return this.receivingWorker.Read(b)
} }
// Write implements the Conn Write method. // Write implements the Conn Write method.
@ -226,6 +247,8 @@ func (this *Connection) Close() error {
return errClosedConnection return errClosedConnection
} }
this.dataInputCond.Broadcast()
state := this.State() state := this.State()
if state == StateReadyToClose || if state == StateReadyToClose ||
state == StateTerminating || state == StateTerminating ||
@ -276,7 +299,7 @@ func (this *Connection) SetReadDeadline(t time.Time) error {
if this == nil || this.State() != StateActive { if this == nil || this.State() != StateActive {
return errClosedConnection return errClosedConnection
} }
this.receivingWorker.SetReadDeadline(t) this.rd = t
return nil return nil
} }
@ -371,6 +394,7 @@ func (this *Connection) Input(data []byte) int {
this.HandleOption(seg.Opt) this.HandleOption(seg.Opt)
this.receivingWorker.ProcessSegment(seg) this.receivingWorker.ProcessSegment(seg)
atomic.StoreUint32(&this.lastPayloadTime, current) atomic.StoreUint32(&this.lastPayloadTime, current)
this.dataInputCond.Signal()
case *AckSegment: case *AckSegment:
this.HandleOption(seg.Opt) this.HandleOption(seg.Opt)
this.sendingWorker.ProcessSegment(current, seg) this.sendingWorker.ProcessSegment(current, seg)

View File

@ -1,9 +1,7 @@
package kcp package kcp
import ( import (
"io"
"sync" "sync"
"time"
"github.com/v2ray/v2ray-core/common/alloc" "github.com/v2ray/v2ray-core/common/alloc"
) )
@ -58,101 +56,68 @@ func (this *ReceivingWindow) Advance() {
} }
type ReceivingQueue struct { type ReceivingQueue struct {
sync.Mutex start uint32
closed bool cap uint32
cache *alloc.Buffer len uint32
queue chan *alloc.Buffer data []*alloc.Buffer
timeout time.Time
} }
func NewReceivingQueue(size uint32) *ReceivingQueue { func NewReceivingQueue(size uint32) *ReceivingQueue {
return &ReceivingQueue{ return &ReceivingQueue{
queue: make(chan *alloc.Buffer, size), cap: size,
data: make([]*alloc.Buffer, size),
} }
} }
func (this *ReceivingQueue) Read(buf []byte) (int, error) { func (this *ReceivingQueue) IsEmpty() bool {
if this.closed { return this.len == 0
return 0, io.EOF }
func (this *ReceivingQueue) IsFull() bool {
return this.len == this.cap
}
func (this *ReceivingQueue) Read(buf []byte) int {
if this.IsEmpty() {
return 0
} }
if this.cache.Len() > 0 { totalBytes := 0
nBytes, err := this.cache.Read(buf) lenBuf := len(buf)
if this.cache.IsEmpty() { for !this.IsEmpty() && totalBytes < lenBuf {
this.cache.Release() payload := this.data[this.start]
this.cache = nil nBytes, _ := payload.Read(buf)
} buf = buf[nBytes:]
return nBytes, err totalBytes += nBytes
} if payload.IsEmpty() {
payload.Release()
var totalBytes int this.data[this.start] = nil
this.start++
L: if this.start == this.cap {
for totalBytes < len(buf) { this.start = 0
timeToSleep := time.Millisecond
select {
case payload, open := <-this.queue:
if !open {
return totalBytes, io.EOF
} }
nBytes, err := payload.Read(buf) this.len--
totalBytes += nBytes if this.len == 0 {
if err != nil { this.start = 0
return totalBytes, err
}
if !payload.IsEmpty() {
this.cache = payload
}
buf = buf[nBytes:]
case <-time.After(timeToSleep):
if totalBytes > 0 {
break L
}
if !this.timeout.IsZero() && this.timeout.Before(time.Now()) {
return totalBytes, errTimeout
}
timeToSleep += 500 * time.Millisecond
if timeToSleep > 5*time.Second {
timeToSleep = 5 * time.Second
} }
} }
} }
return totalBytes
return totalBytes, nil
} }
func (this *ReceivingQueue) Put(payload *alloc.Buffer) bool { func (this *ReceivingQueue) Put(payload *alloc.Buffer) {
if this.closed { this.data[(this.start+this.len)%this.cap] = payload
payload.Release() this.len++
return false
}
select {
case this.queue <- payload:
return true
default:
return false
}
}
func (this *ReceivingQueue) SetReadDeadline(t time.Time) error {
this.timeout = t
return nil
} }
func (this *ReceivingQueue) Close() { func (this *ReceivingQueue) Close() {
this.Lock() for i := uint32(0); i < this.len; i++ {
defer this.Unlock() this.data[(this.start+i)%this.cap].Release()
this.data[(this.start+i)%this.cap] = nil
if this.closed {
return
} }
this.closed = true
close(this.queue)
} }
type AckList struct { type AckList struct {
sync.Mutex
writer SegmentWriter writer SegmentWriter
timestamps []uint32 timestamps []uint32
numbers []uint32 numbers []uint32
@ -169,18 +134,12 @@ func NewAckList(writer SegmentWriter) *AckList {
} }
func (this *AckList) Add(number uint32, timestamp uint32) { func (this *AckList) Add(number uint32, timestamp uint32) {
this.Lock()
defer this.Unlock()
this.timestamps = append(this.timestamps, timestamp) this.timestamps = append(this.timestamps, timestamp)
this.numbers = append(this.numbers, number) this.numbers = append(this.numbers, number)
this.nextFlush = append(this.nextFlush, 0) this.nextFlush = append(this.nextFlush, 0)
} }
func (this *AckList) Clear(una uint32) { func (this *AckList) Clear(una uint32) {
this.Lock()
defer this.Unlock()
count := 0 count := 0
for i := 0; i < len(this.numbers); i++ { for i := 0; i < len(this.numbers); i++ {
if this.numbers[i] >= una { if this.numbers[i] >= una {
@ -201,14 +160,12 @@ func (this *AckList) Clear(una uint32) {
func (this *AckList) Flush(current uint32, rto uint32) { func (this *AckList) Flush(current uint32, rto uint32) {
seg := NewAckSegment() seg := NewAckSegment()
this.Lock()
for i := 0; i < len(this.numbers) && !seg.IsFull(); i++ { for i := 0; i < len(this.numbers) && !seg.IsFull(); i++ {
if this.nextFlush[i] <= current { if this.nextFlush[i] <= current {
seg.PutNumber(this.numbers[i], this.timestamps[i]) seg.PutNumber(this.numbers[i], this.timestamps[i])
this.nextFlush[i] = current + rto/2 this.nextFlush[i] = current + rto/2
} }
} }
this.Unlock()
if seg.Count > 0 { if seg.Count > 0 {
this.writer.Write(seg) this.writer.Write(seg)
seg.Release() seg.Release()
@ -216,14 +173,14 @@ func (this *AckList) Flush(current uint32, rto uint32) {
} }
type ReceivingWorker struct { type ReceivingWorker struct {
conn *Connection sync.Mutex
queue *ReceivingQueue conn *Connection
window *ReceivingWindow queue *ReceivingQueue
windowMutex sync.Mutex window *ReceivingWindow
acklist *AckList acklist *AckList
updated bool updated bool
nextNumber uint32 nextNumber uint32
windowSize uint32 windowSize uint32
} }
func NewReceivingWorker(kcp *Connection) *ReceivingWorker { func NewReceivingWorker(kcp *Connection) *ReceivingWorker {
@ -239,35 +196,35 @@ func NewReceivingWorker(kcp *Connection) *ReceivingWorker {
} }
func (this *ReceivingWorker) ProcessSendingNext(number uint32) { func (this *ReceivingWorker) ProcessSendingNext(number uint32) {
this.Lock()
defer this.Unlock()
this.acklist.Clear(number) this.acklist.Clear(number)
} }
func (this *ReceivingWorker) ProcessSegment(seg *DataSegment) { func (this *ReceivingWorker) ProcessSegment(seg *DataSegment) {
this.Lock()
defer this.Unlock()
number := seg.Number number := seg.Number
idx := number - this.nextNumber idx := number - this.nextNumber
if idx >= this.windowSize { if idx >= this.windowSize {
return return
} }
this.ProcessSendingNext(seg.SendingNext) this.acklist.Clear(seg.SendingNext)
this.acklist.Add(number, seg.Timestamp) this.acklist.Add(number, seg.Timestamp)
this.windowMutex.Lock()
defer this.windowMutex.Unlock()
if !this.window.Set(idx, seg) { if !this.window.Set(idx, seg) {
seg.Release() seg.Release()
} }
for { for !this.queue.IsFull() {
seg := this.window.RemoveFirst() seg := this.window.RemoveFirst()
if seg == nil { if seg == nil {
break break
} }
if !this.queue.Put(seg.Data) { this.queue.Put(seg.Data)
this.window.Set(0, seg)
break
}
seg.Data = nil seg.Data = nil
seg.Release() seg.Release()
this.window.Advance() this.window.Advance()
@ -276,15 +233,17 @@ func (this *ReceivingWorker) ProcessSegment(seg *DataSegment) {
} }
} }
func (this *ReceivingWorker) Read(b []byte) (int, error) { func (this *ReceivingWorker) Read(b []byte) int {
this.Lock()
defer this.Unlock()
return this.queue.Read(b) return this.queue.Read(b)
} }
func (this *ReceivingWorker) SetReadDeadline(t time.Time) {
this.queue.SetReadDeadline(t)
}
func (this *ReceivingWorker) Flush(current uint32) { func (this *ReceivingWorker) Flush(current uint32) {
this.Lock()
defer this.Unlock()
this.acklist.Flush(current, this.conn.roundTrip.Timeout()) this.acklist.Flush(current, this.conn.roundTrip.Timeout())
} }
@ -301,6 +260,9 @@ func (this *ReceivingWorker) Write(seg Segment) {
} }
func (this *ReceivingWorker) CloseRead() { func (this *ReceivingWorker) CloseRead() {
this.Lock()
defer this.Unlock()
this.queue.Close() this.queue.Close()
} }

View File

@ -1,9 +1,7 @@
package kcp_test package kcp_test
import ( import (
"io"
"testing" "testing"
"time"
"github.com/v2ray/v2ray-core/common/alloc" "github.com/v2ray/v2ray-core/common/alloc"
"github.com/v2ray/v2ray-core/testing/assert" "github.com/v2ray/v2ray-core/testing/assert"
@ -42,35 +40,17 @@ func TestRecivingQueue(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
queue := NewReceivingQueue(2) queue := NewReceivingQueue(2)
assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("abcd"))).IsTrue() queue.Put(alloc.NewSmallBuffer().Clear().AppendString("abcd"))
assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("efg"))).IsTrue() queue.Put(alloc.NewSmallBuffer().Clear().AppendString("efg"))
assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("more content"))).IsFalse() assert.Bool(queue.IsFull()).IsTrue()
b := make([]byte, 1024) b := make([]byte, 1024)
nBytes, err := queue.Read(b) nBytes := queue.Read(b)
assert.Error(err).IsNil()
assert.Int(nBytes).Equals(7) assert.Int(nBytes).Equals(7)
assert.String(string(b[:nBytes])).Equals("abcdefg") assert.String(string(b[:nBytes])).Equals("abcdefg")
assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("1"))).IsTrue() queue.Put(alloc.NewSmallBuffer().Clear().AppendString("1"))
queue.Close() queue.Close()
nBytes, err = queue.Read(b) nBytes = queue.Read(b)
assert.Error(err).Equals(io.EOF) assert.Int(nBytes).Equals(0)
}
func TestRecivingQueueTimeout(t *testing.T) {
assert := assert.On(t)
queue := NewReceivingQueue(2)
assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("abcd"))).IsTrue()
queue.SetReadDeadline(time.Now().Add(time.Second))
b := make([]byte, 1024)
nBytes, err := queue.Read(b)
assert.Error(err).IsNil()
assert.Int(nBytes).Equals(4)
assert.String(string(b[:nBytes])).Equals("abcd")
nBytes, err = queue.Read(b)
assert.Error(err).IsNotNil()
} }