From a6c0ef11ba2026e9f8bfad618ebc9c7e8287ffa1 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Wed, 17 Jan 2018 16:18:38 +0100 Subject: [PATCH] check connection state for every write operation --- transport/internet/kcp/connection.go | 49 +++++++++++++++++++--------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index d7e69f9d..b7b0652f 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -341,21 +341,30 @@ func (c *Connection) Write(b []byte) (int, error) { totalWritten := 0 for { - if c == nil || c.State() != StateActive { - return totalWritten, io.ErrClosedPipe - } + dataWritten := false + for { + if c == nil || c.State() != StateActive { + return totalWritten, io.ErrClosedPipe + } + if !c.sendingWorker.Push(func(bb []byte) (int, error) { + n := copy(bb[:c.mss], b[totalWritten:]) + totalWritten += n + return n, nil + }) { + break + } + + dataWritten = true - for c.sendingWorker.Push(func(bb []byte) (int, error) { - n := copy(bb[:c.mss], b[totalWritten:]) - totalWritten += n - return n, nil - }) { - c.dataUpdater.WakeUp() if totalWritten == len(b) { return totalWritten, nil } } + if dataWritten { + c.dataUpdater.WakeUp() + } + if err := c.waitForDataOutput(); err != nil { return totalWritten, err } @@ -367,19 +376,27 @@ func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { defer mb.Release() for { - if c == nil || c.State() != StateActive { - return io.ErrClosedPipe - } + dataWritten := false + for { + if c == nil || c.State() != StateActive { + return io.ErrClosedPipe + } - for c.sendingWorker.Push(func(bb []byte) (int, error) { - return mb.Read(bb[:c.mss]) - }) { - c.dataUpdater.WakeUp() + if !c.sendingWorker.Push(func(bb []byte) (int, error) { + return mb.Read(bb[:c.mss]) + }) { + break + } + dataWritten = true if mb.IsEmpty() { return nil } } + if dataWritten { + c.dataUpdater.WakeUp() + } + if err := c.waitForDataOutput(); err != nil { return err }