diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 280b20a3..e44254cd 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -188,7 +188,7 @@ type Connection struct { receivingWorker *ReceivingWorker sendingWorker *SendingWorker - output *BufferedSegmentWriter + output SegmentWriter dataUpdater *Updater pingUpdater *Updater @@ -208,7 +208,7 @@ func NewConnection(conv uint16, sysConn SystemConnection, recycler internal.Conn dataInput: make(chan bool, 1), dataOutput: make(chan bool, 1), Config: config, - output: NewSegmentWriter(sysConn, config.GetMtu().GetValue()-uint32(sysConn.Overhead())), + output: NewSegmentWriter(sysConn), mss: config.GetMtu().GetValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead, roundTrip: &RoundTripInfo{ rto: 100, @@ -542,7 +542,6 @@ func (v *Connection) flush() { if v.State() == StateTerminating { log.Debug("KCP|Connection: #", v.conv, " sending terminating cmd.") v.Ping(current, CommandTerminate) - v.output.Flush() if current-atomic.LoadUint32(&v.stateBeginTime) > 8000 { v.SetState(StateTerminated) @@ -564,9 +563,6 @@ func (v *Connection) flush() { if current-atomic.LoadUint32(&v.lastPingTime) >= 3000 { v.Ping(current, CommandPing) } - - // flash remain segments - v.output.Flush() } func (v *Connection) State() State { diff --git a/transport/internet/kcp/output.go b/transport/internet/kcp/output.go index 47bd6ee7..98b054f1 100644 --- a/transport/internet/kcp/output.go +++ b/transport/internet/kcp/output.go @@ -8,48 +8,28 @@ import ( ) type SegmentWriter interface { - Write(seg Segment) + Write(seg Segment) error } -type BufferedSegmentWriter struct { +type SimpleSegmentWriter struct { sync.Mutex - mtu uint32 buffer *buf.Buffer writer io.Writer } -func NewSegmentWriter(writer io.Writer, mtu uint32) *BufferedSegmentWriter { - return &BufferedSegmentWriter{ - mtu: mtu, +func NewSegmentWriter(writer io.Writer) SegmentWriter { + return &SimpleSegmentWriter{ writer: writer, buffer: buf.NewSmall(), } } -func (v *BufferedSegmentWriter) Write(seg Segment) { +func (v *SimpleSegmentWriter) Write(seg Segment) error { v.Lock() defer v.Unlock() - nBytes := seg.ByteSize() - if uint32(v.buffer.Len()+nBytes) > v.mtu { - v.FlushWithoutLock() - } - v.buffer.AppendSupplier(seg.Bytes()) -} - -func (v *BufferedSegmentWriter) FlushWithoutLock() { - v.writer.Write(v.buffer.Bytes()) + _, err := v.writer.Write(v.buffer.Bytes()) v.buffer.Clear() -} - -func (v *BufferedSegmentWriter) Flush() { - v.Lock() - defer v.Unlock() - - if v.buffer.IsEmpty() { - return - } - - v.FlushWithoutLock() + return err } diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index c924f470..70838807 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -235,7 +235,7 @@ func (v *ReceivingWorker) Flush(current uint32) { v.acklist.Flush(current, v.conn.roundTrip.Timeout()) } -func (v *ReceivingWorker) Write(seg Segment) { +func (v *ReceivingWorker) Write(seg Segment) error { ackSeg := seg.(*AckSegment) ackSeg.Conv = v.conn.conv ackSeg.ReceivingNext = v.nextNumber @@ -243,7 +243,7 @@ func (v *ReceivingWorker) Write(seg Segment) { if v.conn.state == StateReadyToClose { ackSeg.Option = SegmentOptionClose } - v.conn.output.Write(ackSeg) + return v.conn.output.Write(ackSeg) } func (v *ReceivingWorker) CloseRead() { diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index b058ae09..8336f7a5 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -305,7 +305,7 @@ func (v *SendingWorker) Push(b []byte) int { } // Private: Visible for testing. -func (v *SendingWorker) Write(seg Segment) { +func (v *SendingWorker) Write(seg Segment) error { dataSeg := seg.(*DataSegment) dataSeg.Conv = v.conn.conv @@ -315,7 +315,7 @@ func (v *SendingWorker) Write(seg Segment) { dataSeg.Option = SegmentOptionClose } - v.conn.output.Write(dataSeg) + return v.conn.output.Write(dataSeg) } func (v *SendingWorker) OnPacketLoss(lossRate uint32) {