diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 6ea411ac..860af5eb 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -109,7 +109,6 @@ type Connection struct { state State stateBeginTime uint32 lastIncomingTime uint32 - sendingUpdated bool lastPingTime uint32 mss uint32 @@ -463,13 +462,13 @@ func (this *Connection) flush() { } this.output.Write(seg) this.lastPingTime = current - this.sendingUpdated = false + this.sendingWorker.MarkPingNecessary(false) + this.receivingWorker.MarkPingNecessary(false) seg.Release() } // flash remain segments this.output.Flush() - } func (this *Connection) State() State { diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index 33931803..0a9a905b 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -173,7 +173,7 @@ func (this *AckList) Flush(current uint32, rto uint32) { } type ReceivingWorker struct { - sync.Mutex + sync.RWMutex conn *Connection queue *ReceivingQueue window *ReceivingWindow @@ -267,5 +267,13 @@ func (this *ReceivingWorker) CloseRead() { } func (this *ReceivingWorker) PingNecessary() bool { + this.RLock() + defer this.RUnlock() return this.updated } + +func (this *ReceivingWorker) MarkPingNecessary(b bool) { + this.Lock() + defer this.Unlock() + this.updated = b +} diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index 0ea7f202..03fe998b 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -227,7 +227,7 @@ func (this *SendingQueue) Len() uint32 { } type SendingWorker struct { - sync.Mutex + sync.RWMutex conn *Connection window *SendingWindow queue *SendingQueue @@ -347,9 +347,19 @@ func (this *SendingWorker) Write(seg Segment) { } func (this *SendingWorker) PingNecessary() bool { + this.RLock() + defer this.RUnlock() + return this.updated } +func (this *SendingWorker) MarkPingNecessary(b bool) { + this.Lock() + defer this.Unlock() + + this.updated = b +} + func (this *SendingWorker) OnPacketLoss(lossRate uint32) { if !effectiveConfig.Congestion || this.conn.roundTrip.Timeout() == 0 { return