diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index 3552f6a6..11d67fdb 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -58,6 +58,7 @@ func (this *ReceivingWindow) Advance() { } type ReceivingQueue struct { + sync.Mutex closed bool cache *alloc.Buffer queue chan *alloc.Buffer @@ -114,6 +115,9 @@ L: } func (this *ReceivingQueue) Put(payload *alloc.Buffer) bool { + this.Lock() + defer this.Unlock() + if this.closed { payload.Release() return false @@ -133,6 +137,9 @@ func (this *ReceivingQueue) SetReadDeadline(t time.Time) error { } func (this *ReceivingQueue) Close() { + this.Lock() + defer this.Unlock() + if this.closed { return } @@ -141,6 +148,7 @@ func (this *ReceivingQueue) Close() { } type AckList struct { + sync.Mutex writer SegmentWriter timestamps []uint32 numbers []uint32 @@ -157,12 +165,18 @@ func NewACKList(writer SegmentWriter) *AckList { } func (this *AckList) Add(number uint32, timestamp uint32) { + this.Lock() + defer this.Unlock() + this.timestamps = append(this.timestamps, timestamp) this.numbers = append(this.numbers, number) this.nextFlush = append(this.nextFlush, 0) } func (this *AckList) Clear(una uint32) { + this.Lock() + defer this.Unlock() + count := 0 for i := 0; i < len(this.numbers); i++ { if this.numbers[i] >= una { @@ -181,33 +195,35 @@ func (this *AckList) Clear(una uint32) { } } -func (this *AckList) Flush(current uint32) { +func (this *AckList) Flush(current uint32, rto uint32) { seg := new(AckSegment) + this.Lock() for i := 0; i < len(this.numbers); i++ { if this.nextFlush[i] <= current { seg.Count++ seg.NumberList = append(seg.NumberList, this.numbers[i]) seg.TimestampList = append(seg.TimestampList, this.timestamps[i]) - this.nextFlush[i] = current + 100 + this.nextFlush[i] = current + rto/2 if seg.Count == 128 { break } } } + this.Unlock() if seg.Count > 0 { this.writer.Write(seg) } } type ReceivingWorker struct { - sync.Mutex - kcp *KCP - queue *ReceivingQueue - window *ReceivingWindow - acklist *AckList - updated bool - nextNumber uint32 - windowSize uint32 + kcp *KCP + queue *ReceivingQueue + window *ReceivingWindow + windowMutex sync.Mutex + acklist *AckList + updated bool + nextNumber uint32 + windowSize uint32 } func NewReceivingWorker(kcp *KCP) *ReceivingWorker { @@ -223,9 +239,6 @@ func NewReceivingWorker(kcp *KCP) *ReceivingWorker { } func (this *ReceivingWorker) ProcessSendingNext(number uint32) { - this.Lock() - defer this.Unlock() - this.acklist.Clear(number) } @@ -237,23 +250,22 @@ func (this *ReceivingWorker) ProcessSegment(seg *DataSegment) { this.ProcessSendingNext(seg.SendingNext) - this.Lock() this.acklist.Add(number, seg.Timestamp) - + this.windowMutex.Lock() idx := number - this.nextNumber if !this.window.Set(idx, seg) { seg.Release() } - this.Unlock() + this.windowMutex.Unlock() this.DumpWindow() } // @Private func (this *ReceivingWorker) DumpWindow() { - this.Lock() - defer this.Unlock() + this.windowMutex.Lock() + defer this.windowMutex.Unlock() for { seg := this.window.RemoveFirst() @@ -278,17 +290,11 @@ func (this *ReceivingWorker) Read(b []byte) (int, error) { } func (this *ReceivingWorker) SetReadDeadline(t time.Time) { - this.Lock() - defer this.Unlock() - this.queue.SetReadDeadline(t) } func (this *ReceivingWorker) Flush() { - this.Lock() - defer this.Unlock() - - this.acklist.Flush(this.kcp.current) + this.acklist.Flush(this.kcp.current, this.kcp.rx_rto) } func (this *ReceivingWorker) Write(seg ISegment) { @@ -304,9 +310,6 @@ func (this *ReceivingWorker) Write(seg ISegment) { } func (this *ReceivingWorker) CloseRead() { - this.Lock() - defer this.Unlock() - this.queue.Close() }