diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index f7f0be55..4cf82736 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "time" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/signal" "v2ray.com/core/common/signal/semaphore" @@ -342,43 +343,12 @@ func (c *Connection) waitForDataOutput() error { // Write implements io.Writer. func (c *Connection) Write(b []byte) (int, error) { - updatePending := false - defer func() { - if updatePending { - c.dataUpdater.WakeUp() - } - }() - - for { - totalWritten := 0 - for { - if c == nil || c.State() != StateActive { - return totalWritten, io.ErrClosedPipe - } - if !c.sendingWorker.Push(func(bb []byte) (int, error) { - n := copy(bb[:c.mss], b[totalWritten:]) - totalWritten += n - return n, nil - }) { - break - } - - updatePending = true - - if totalWritten == len(b) { - return totalWritten, nil - } - } - - if updatePending { - c.dataUpdater.WakeUp() - updatePending = false - } - - if err := c.waitForDataOutput(); err != nil { - return totalWritten, err - } + var mb buf.MultiBuffer + common.Must2(mb.Write(b)) + if err := c.WriteMultiBuffer(mb); err != nil { + return 0, err } + return len(b), nil } // WriteMultiBuffer implements buf.Writer. @@ -392,19 +362,13 @@ func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { } }() - f := func(x *buf.MultiBuffer) buf.Supplier { - return func(bb []byte) (int, error) { - return x.Read(bb[:c.mss]) - } - }(&mb) - for { for { if c == nil || c.State() != StateActive { return io.ErrClosedPipe } - if !c.sendingWorker.Push(f) { + if !c.sendingWorker.Push(&mb) { break } updatePending = true diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index 4d5d47ce..ca1fb6bb 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -43,12 +43,12 @@ func (sw *SendingWindow) IsEmpty() bool { return sw.cache.Len() == 0 } -func (sw *SendingWindow) Push(number uint32) *buf.Buffer { +func (sw *SendingWindow) Push(number uint32, b *buf.Buffer) { seg := NewDataSegment() seg.Number = number + seg.payload = b sw.cache.PushBack(seg) - return seg.Data() } func (sw *SendingWindow) FirstNumber() uint32 { @@ -261,7 +261,7 @@ func (w *SendingWorker) ProcessSegment(current uint32, seg *AckSegment, rto uint } } -func (w *SendingWorker) Push(f buf.Supplier) bool { +func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool { w.Lock() defer w.Unlock() @@ -273,9 +273,12 @@ func (w *SendingWorker) Push(f buf.Supplier) bool { return false } - b := w.window.Push(w.nextNumber) + b := buf.New() + common.Must(b.Reset(func(v []byte) (int, error) { + return mb.Read(v[:w.conn.mss]) + })) + w.window.Push(w.nextNumber, b) w.nextNumber++ - common.Must(b.Reset(f)) return true }