diff --git a/transport/internet/kcp/kcp.go b/transport/internet/kcp/kcp.go index 4e23f972..f13df238 100644 --- a/transport/internet/kcp/kcp.go +++ b/transport/internet/kcp/kcp.go @@ -81,7 +81,7 @@ func NewKCP(conv uint16, mtu uint32, sendingWindowSize uint32, receivingWindowSi kcp.rcv_buf = NewReceivingWindow(receivingWindowSize) kcp.snd_queue = NewSendingQueue(sendingQueueSize) kcp.rcv_queue = NewReceivingQueue() - kcp.acklist = new(ACKList) + kcp.acklist = NewACKList(kcp) kcp.cwnd = kcp.snd_wnd return kcp } @@ -250,9 +250,7 @@ func (kcp *KCP) HandleReceivingNext(receivingNext uint32) { } func (kcp *KCP) HandleSendingNext(sendingNext uint32) { - if kcp.acklist.Clear(sendingNext) { - kcp.receivingUpdated = true - } + kcp.acklist.Clear(sendingNext) } func (kcp *KCP) parse_data(newseg *DataSegment) { @@ -367,16 +365,9 @@ func (kcp *KCP) flush() { lost := false // flush acknowledges - //if kcp.receivingUpdated { - ackSeg := kcp.acklist.AsSegment() - if ackSeg != nil { - ackSeg.Conv = kcp.conv - ackSeg.ReceivingWindow = uint32(kcp.rcv_nxt + kcp.rcv_wnd) - ackSeg.ReceivingNext = kcp.rcv_nxt - kcp.output.Write(ackSeg) + if kcp.acklist.Flush() { kcp.receivingUpdated = false } - //} // calculate window size cwnd := kcp.snd_una + kcp.snd_wnd diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index 22c3c578..342712ea 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -149,22 +149,35 @@ func (this *ReceivingQueue) Close() { } type ACKList struct { + kcp *KCP timestamps []uint32 numbers []uint32 + nextFlush []uint32 +} + +func NewACKList(kcp *KCP) *ACKList { + return &ACKList{ + kcp: kcp, + timestamps: make([]uint32, 0, 32), + numbers: make([]uint32, 0, 32), + nextFlush: make([]uint32, 0, 32), + } } func (this *ACKList) Add(number uint32, timestamp uint32) { this.timestamps = append(this.timestamps, timestamp) this.numbers = append(this.numbers, number) + this.nextFlush = append(this.nextFlush, 0) } -func (this *ACKList) Clear(una uint32) bool { +func (this *ACKList) Clear(una uint32) { count := 0 for i := 0; i < len(this.numbers); i++ { if this.numbers[i] >= una { if i != count { this.numbers[count] = this.numbers[i] this.timestamps[count] = this.timestamps[i] + this.nextFlush[count] = this.nextFlush[i] } count++ } @@ -172,26 +185,34 @@ func (this *ACKList) Clear(una uint32) bool { if count < len(this.numbers) { this.numbers = this.numbers[:count] this.timestamps = this.timestamps[:count] - return true + this.nextFlush = this.nextFlush[:count] } - return false } -func (this *ACKList) AsSegment() *ACKSegment { - count := len(this.numbers) - if count == 0 { - return nil +func (this *ACKList) Flush() bool { + seg := &ACKSegment{ + Conv: this.kcp.conv, + ReceivingNext: this.kcp.rcv_nxt, + ReceivingWindow: this.kcp.rcv_nxt + this.kcp.rcv_wnd, } - - if count > 128 { - count = 128 + if this.kcp.state == StateReadyToClose { + seg.Opt = SegmentOptionClose } - seg := &ACKSegment{ - Count: byte(count), - NumberList: this.numbers[:count], - TimestampList: this.timestamps[:count], + current := this.kcp.current + for i := 0; i < len(this.numbers); i++ { + if this.nextFlush[i] <= current { + seg.Count++ + seg.NumberList = append(seg.NumberList, this.numbers[i]) + seg.TimestampList = append(seg.TimestampList, this.timestamps[i]) + this.nextFlush[i] = current + 50 + if seg.Count == 128 { + break + } + } } - //this.numbers = nil - //this.timestamps = nil - return seg + if seg.Count > 0 { + this.kcp.output.Write(seg) + return true + } + return false }