diff --git a/common/buf/multi_buffer_test.go b/common/buf/multi_buffer_test.go index 9b9e0c93..bf7059d4 100644 --- a/common/buf/multi_buffer_test.go +++ b/common/buf/multi_buffer_test.go @@ -23,3 +23,13 @@ func TestMultiBufferRead(t *testing.T) { assert.Int(nBytes).Equals(4) assert.Bytes(bs[:nBytes]).Equals([]byte("abcd")) } + +func TestMultiBufferAppend(t *testing.T) { + assert := assert.On(t) + + var mb MultiBuffer + b := New() + b.AppendBytes('a', 'b') + mb.Append(b) + assert.Int(mb.Len()).Equals(2) +} diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 43f0a9c7..a395a162 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -169,6 +169,7 @@ type SystemConnection interface { } var ( + _ buf.MultiBufferReader = (*Connection)(nil) _ buf.MultiBufferWriter = (*Connection)(nil) ) @@ -264,6 +265,43 @@ func (v *Connection) OnDataOutput() { } } +// ReadMultiBuffer implements buf.MultiBufferReader. +func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { + if v == nil { + return nil, io.EOF + } + + for { + if v.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { + return nil, io.EOF + } + mb := v.receivingWorker.ReadMultiBuffer() + if !mb.IsEmpty() { + return mb, nil + } + + if v.State() == StatePeerTerminating { + return nil, io.EOF + } + + duration := time.Minute + if !v.rd.IsZero() { + duration = v.rd.Sub(time.Now()) + if duration < 0 { + return nil, ErrIOTimeout + } + } + + select { + case <-v.dataInput: + case <-time.After(duration): + if !v.rd.IsZero() && v.rd.Before(time.Now()) { + return nil, ErrIOTimeout + } + } + } +} + // Read implements the Conn Read method. func (v *Connection) Read(b []byte) (int, error) { if v == nil { diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index b5a0cc80..7c10dad5 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -149,7 +149,7 @@ func (v *AckList) Flush(current uint32, rto uint32) { type ReceivingWorker struct { sync.RWMutex conn *Connection - leftOver *buf.Buffer + leftOver buf.MultiBuffer window *ReceivingWindow acklist *AckList nextNumber uint32 @@ -196,42 +196,39 @@ func (v *ReceivingWorker) ProcessSegment(seg *DataSegment) { } } -func (v *ReceivingWorker) Read(b []byte) int { - v.Lock() - defer v.Unlock() - - total := 0 +func (v *ReceivingWorker) ReadMultiBuffer() buf.MultiBuffer { if v.leftOver != nil { - nBytes := copy(b, v.leftOver.Bytes()) - if nBytes < v.leftOver.Len() { - v.leftOver.SliceFrom(nBytes) - return nBytes - } - v.leftOver.Release() + mb := v.leftOver v.leftOver = nil - total += nBytes + return mb } - for total < len(b) { + mb := buf.NewMultiBuffer() + + v.Lock() + defer v.Unlock() + for { seg := v.window.RemoveFirst() if seg == nil { break } v.window.Advance() v.nextNumber++ - - nBytes := copy(b[total:], seg.Data.Bytes()) - total += nBytes - if nBytes < seg.Data.Len() { - seg.Data.SliceFrom(nBytes) - v.leftOver = seg.Data - seg.Data = nil - seg.Release() - break - } + mb.Append(seg.Data) + seg.Data = nil seg.Release() } - return total + + return mb +} + +func (v *ReceivingWorker) Read(b []byte) int { + mb := v.ReadMultiBuffer() + nBytes, _ := mb.Read(b) + if !mb.IsEmpty() { + v.leftOver = mb + } + return nBytes } func (w *ReceivingWorker) IsDataAvailable() bool {