diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 12383823..bf718c5d 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -88,7 +88,7 @@ func (this *Connection) Write(b []byte) (int, error) { this.RUnlock() this.kcpAccess.Lock() - nBytes := this.kcp.Send(b[totalWritten:]) + nBytes := this.kcp.sendingWorker.Push(b[totalWritten:]) if nBytes > 0 { totalWritten += nBytes if totalWritten == len(b) { diff --git a/transport/internet/kcp/kcp.go b/transport/internet/kcp/kcp.go index 2e49b3c0..39985a97 100644 --- a/transport/internet/kcp/kcp.go +++ b/transport/internet/kcp/kcp.go @@ -6,7 +6,6 @@ package kcp import ( - "github.com/v2ray/v2ray-core/common/alloc" "github.com/v2ray/v2ray-core/common/log" ) @@ -35,14 +34,11 @@ type KCP struct { lastPingTime uint32 mss uint32 - snd_una, snd_nxt uint32 rx_rttvar, rx_srtt, rx_rto uint32 - snd_wnd, rmt_wnd, cwnd uint32 current, interval uint32 - snd_queue *SendingQueue - snd_buf *SendingWindow receivingWorker *ReceivingWorker + sendingWorker *SendingWorker fastresend uint32 congestionControl bool @@ -55,18 +51,14 @@ func NewKCP(conv uint16, output *AuthenticationWriter) *KCP { log.Debug("KCP|Core: creating KCP ", conv) kcp := new(KCP) kcp.conv = conv - kcp.snd_wnd = effectiveConfig.GetSendingWindowSize() - kcp.rmt_wnd = 32 kcp.mss = output.Mtu() - DataSegmentOverhead kcp.rx_rto = 100 kcp.interval = effectiveConfig.Tti kcp.output = NewSegmentWriter(output) - kcp.snd_queue = NewSendingQueue(effectiveConfig.GetSendingQueueSize()) - kcp.snd_buf = NewSendingWindow(kcp, effectiveConfig.GetSendingWindowSize()) - kcp.cwnd = kcp.snd_wnd kcp.receivingWorker = NewReceivingWorker(kcp) kcp.fastresend = 2 kcp.congestionControl = effectiveConfig.Congestion + kcp.sendingWorker = NewSendingWorker(kcp) return kcp } @@ -78,11 +70,13 @@ func (kcp *KCP) SetState(state State) { case StateReadyToClose: kcp.receivingWorker.CloseRead() case StatePeerClosed: - kcp.ClearSendQueue() + kcp.sendingWorker.CloseWrite() case StateTerminating: kcp.receivingWorker.CloseRead() + kcp.sendingWorker.CloseWrite() case StateTerminated: kcp.receivingWorker.CloseRead() + kcp.sendingWorker.CloseWrite() } } @@ -110,26 +104,6 @@ func (kcp *KCP) OnClose() { } } -// Send is user/upper level send, returns below zero for error -func (kcp *KCP) Send(buffer []byte) int { - nBytes := 0 - for len(buffer) > 0 && !kcp.snd_queue.IsFull() { - var size int - if len(buffer) > int(kcp.mss) { - size = int(kcp.mss) - } else { - size = len(buffer) - } - seg := &DataSegment{ - Data: alloc.NewSmallBuffer().Clear().Append(buffer[:size]), - } - kcp.snd_queue.Push(seg) - buffer = buffer[size:] - nBytes += size - } - return nBytes -} - // https://tools.ietf.org/html/rfc6298 func (kcp *KCP) update_ack(rtt int32) { if kcp.rx_srtt == 0 { @@ -159,46 +133,11 @@ func (kcp *KCP) update_ack(rtt int32) { kcp.rx_rto = rto * 3 / 2 } -func (kcp *KCP) shrink_buf() { - prevUna := kcp.snd_una - if kcp.snd_buf.Len() > 0 { - seg := kcp.snd_buf.First() - kcp.snd_una = seg.Number - } else { - kcp.snd_una = kcp.snd_nxt - } - if kcp.snd_una != prevUna { - kcp.sendingUpdated = true - } -} - -func (kcp *KCP) parse_ack(sn uint32) { - if _itimediff(sn, kcp.snd_una) < 0 || _itimediff(sn, kcp.snd_nxt) >= 0 { - return - } - - kcp.snd_buf.Remove(sn - kcp.snd_una) -} - -func (kcp *KCP) parse_fastack(sn uint32) { - if _itimediff(sn, kcp.snd_una) < 0 || _itimediff(sn, kcp.snd_nxt) >= 0 { - return - } - - kcp.snd_buf.HandleFastAck(sn) -} - -func (kcp *KCP) HandleReceivingNext(receivingNext uint32) { - kcp.snd_buf.Clear(receivingNext) -} - // Input when you received a low level packet (eg. UDP packet), call it func (kcp *KCP) Input(data []byte) int { kcp.lastIncomingTime = kcp.current var seg ISegment - var maxack uint32 - var flag int for { seg, data = ReadSegment(data) if seg == nil { @@ -212,26 +151,7 @@ func (kcp *KCP) Input(data []byte) int { kcp.lastPayloadTime = kcp.current case *AckSegment: kcp.HandleOption(seg.Opt) - if kcp.rmt_wnd < seg.ReceivingWindow { - kcp.rmt_wnd = seg.ReceivingWindow - } - kcp.HandleReceivingNext(seg.ReceivingNext) - kcp.shrink_buf() - for i := 0; i < int(seg.Count); i++ { - ts := seg.TimestampList[i] - sn := seg.NumberList[i] - if _itimediff(kcp.current, ts) >= 0 { - kcp.update_ack(_itimediff(kcp.current, ts)) - } - kcp.parse_ack(sn) - kcp.shrink_buf() - if flag == 0 { - flag = 1 - maxack = sn - } else if _itimediff(sn, maxack) > 0 { - maxack = sn - } - } + kcp.sendingWorker.ProcessAckSegment(seg) kcp.lastPayloadTime = kcp.current case *CmdOnlySegment: kcp.HandleOption(seg.Opt) @@ -244,17 +164,12 @@ func (kcp *KCP) Input(data []byte) int { kcp.SetState(StateTerminated) } } - kcp.HandleReceivingNext(seg.ReceivinNext) + kcp.sendingWorker.ProcessReceivingNext(seg.ReceivinNext) kcp.receivingWorker.ProcessSendingNext(seg.SendingNext) - kcp.shrink_buf() default: } } - if flag != 0 { - kcp.parse_fastack(maxack) - } - return 0 } @@ -284,42 +199,16 @@ func (kcp *KCP) flush() { kcp.SetState(StateTerminating) } - current := kcp.current - // flush acknowledges kcp.receivingWorker.Flush() + kcp.sendingWorker.Flush() - // calculate window size - cwnd := kcp.snd_una + kcp.snd_wnd - if cwnd > kcp.rmt_wnd { - cwnd = kcp.rmt_wnd - } - if kcp.congestionControl && cwnd > kcp.snd_una+kcp.cwnd { - cwnd = kcp.snd_una + kcp.cwnd - } - - for !kcp.snd_queue.IsEmpty() && _itimediff(kcp.snd_nxt, cwnd) < 0 { - seg := kcp.snd_queue.Pop() - seg.Conv = kcp.conv - seg.Number = kcp.snd_nxt - seg.timeout = current - seg.ackSkipped = 0 - seg.transmit = 0 - kcp.snd_buf.Push(seg) - kcp.snd_nxt++ - } - - // flush data segments - if kcp.snd_buf.Flush() { - kcp.sendingUpdated = false - } - - if kcp.sendingUpdated || kcp.receivingWorker.PingNecessary() || _itimediff(kcp.current, kcp.lastPingTime) >= 5000 { + if kcp.sendingWorker.PingNecessary() || kcp.receivingWorker.PingNecessary() || _itimediff(kcp.current, kcp.lastPingTime) >= 5000 { seg := &CmdOnlySegment{ Conv: kcp.conv, Cmd: SegmentCommandPing, ReceivinNext: kcp.receivingWorker.nextNumber, - SendingNext: kcp.snd_una, + SendingNext: kcp.sendingWorker.firstUnacknowledged, } if kcp.state == StateReadyToClose { seg.Opt = SegmentOptionClose @@ -334,23 +223,6 @@ func (kcp *KCP) flush() { } -func (kcp *KCP) HandleLost(lost bool) { - if !kcp.congestionControl { - return - } - if lost { - kcp.cwnd = 3 * kcp.cwnd / 4 - } else { - kcp.cwnd += kcp.cwnd / 4 - } - if kcp.cwnd < 4 { - kcp.cwnd = 4 - } - if kcp.cwnd > kcp.snd_wnd { - kcp.cwnd = kcp.snd_wnd - } -} - // Update updates state (call it repeatedly, every 10ms-100ms), or you can ask // ikcp_check when to call it again (without ikcp_input/_send calling). // 'current' - current timestamp in millisec. @@ -358,13 +230,3 @@ func (kcp *KCP) Update(current uint32) { kcp.current = current kcp.flush() } - -// WaitSnd gets how many packet is waiting to be sent -func (kcp *KCP) WaitSnd() uint32 { - return uint32(kcp.snd_buf.Len()) + kcp.snd_queue.Len() -} - -func (this *KCP) ClearSendQueue() { - this.snd_queue.Clear() - this.snd_buf.Clear(0xFFFFFFFF) -} diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index efe0cd61..e2fc2ae1 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -1,5 +1,11 @@ package kcp +import ( + "sync" + + "github.com/v2ray/v2ray-core/common/alloc" +) + type SendingWindow struct { start uint32 cap uint32 @@ -10,19 +16,21 @@ type SendingWindow struct { prev []uint32 next []uint32 - kcp *KCP + writer SegmentWriter + onPacketLoss func(bool) } -func NewSendingWindow(kcp *KCP, size uint32) *SendingWindow { +func NewSendingWindow(size uint32, writer SegmentWriter, onPacketLoss func(bool)) *SendingWindow { window := &SendingWindow{ - start: 0, - cap: size, - len: 0, - last: 0, - data: make([]*DataSegment, size), - prev: make([]uint32, size), - next: make([]uint32, size), - kcp: kcp, + start: 0, + cap: size, + len: 0, + last: 0, + data: make([]*DataSegment, size), + prev: make([]uint32, size), + next: make([]uint32, size), + writer: writer, + onPacketLoss: onPacketLoss, } return window } @@ -102,15 +110,12 @@ func (this *SendingWindow) HandleFastAck(number uint32) { } } -func (this *SendingWindow) Flush() bool { +func (this *SendingWindow) Flush(current uint32, resend uint32, rto uint32) { if this.Len() == 0 { - return false + return } - current := this.kcp.current - resent := this.kcp.fastresend lost := false - segSent := false for i := this.start; ; i = this.next[i] { segment := this.data[i] @@ -118,39 +123,29 @@ func (this *SendingWindow) Flush() bool { if segment.transmit == 0 { needsend = true segment.transmit++ - segment.timeout = current + this.kcp.rx_rto + segment.timeout = current + rto } else if _itimediff(current, segment.timeout) >= 0 { needsend = true segment.transmit++ - segment.timeout = current + this.kcp.rx_rto + segment.timeout = current + rto lost = true - } else if segment.ackSkipped >= resent { + } else if segment.ackSkipped >= resend { needsend = true segment.transmit++ segment.ackSkipped = 0 - segment.timeout = current + this.kcp.rx_rto + segment.timeout = current + rto lost = true } if needsend { - segment.Timestamp = current - segment.SendingNext = this.kcp.snd_una - segment.Opt = 0 - if this.kcp.state == StateReadyToClose { - segment.Opt = SegmentOptionClose - } - - this.kcp.output.Write(segment) - segSent = true + this.writer.Write(segment) } if i == this.last { break } } - this.kcp.HandleLost(lost) - - return segSent + this.onPacketLoss(lost) } type SendingQueue struct { @@ -211,3 +206,175 @@ func (this *SendingQueue) Clear() { func (this *SendingQueue) Len() uint32 { return this.len } + +type SendingWorker struct { + sync.Mutex + kcp *KCP + window *SendingWindow + queue *SendingQueue + windowSize uint32 + firstUnacknowledged uint32 + nextNumber uint32 + remoteNextNumber uint32 + controlWindow uint32 + fastResend uint32 + updated bool +} + +func NewSendingWorker(kcp *KCP) *SendingWorker { + worker := &SendingWorker{ + kcp: kcp, + queue: NewSendingQueue(effectiveConfig.GetSendingQueueSize()), + fastResend: 2, + remoteNextNumber: 32, + windowSize: effectiveConfig.GetSendingWindowSize(), + controlWindow: effectiveConfig.GetSendingWindowSize(), + } + worker.window = NewSendingWindow(effectiveConfig.GetSendingWindowSize(), worker, worker.OnPacketLoss) + return worker +} + +func (this *SendingWorker) ProcessReceivingNext(nextNumber uint32) { + this.Lock() + defer this.Unlock() + + this.window.Clear(nextNumber) + this.FindFirstUnacknowledged() +} + +// @Private +func (this *SendingWorker) FindFirstUnacknowledged() { + prevUna := this.firstUnacknowledged + if this.window.Len() > 0 { + this.firstUnacknowledged = this.window.First().Number + } else { + this.firstUnacknowledged = this.nextNumber + } + if this.firstUnacknowledged != prevUna { + this.updated = true + } +} + +func (this *SendingWorker) ProcessAck(number uint32) { + if number-this.firstUnacknowledged > this.windowSize { + return + } + + this.Lock() + defer this.Unlock() + this.window.Remove(number - this.firstUnacknowledged) + this.FindFirstUnacknowledged() +} + +func (this *SendingWorker) ProcessAckSegment(seg *AckSegment) { + if this.remoteNextNumber < seg.ReceivingWindow { + this.remoteNextNumber = seg.ReceivingWindow + } + this.ProcessReceivingNext(seg.ReceivingNext) + var maxack uint32 + for i := 0; i < int(seg.Count); i++ { + timestamp := seg.TimestampList[i] + number := seg.NumberList[i] + if this.kcp.current-timestamp > 10000 { + this.kcp.update_ack(int32(this.kcp.current - timestamp)) + } + this.ProcessAck(number) + if maxack < number { + maxack = number + } + } + this.Lock() + this.window.HandleFastAck(maxack) + this.Unlock() +} + +func (this *SendingWorker) Push(b []byte) int { + nBytes := 0 + for len(b) > 0 && !this.queue.IsFull() { + var size int + if len(b) > int(this.kcp.mss) { + size = int(this.kcp.mss) + } else { + size = len(b) + } + seg := &DataSegment{ + Data: alloc.NewSmallBuffer().Clear().Append(b[:size]), + } + this.Lock() + this.queue.Push(seg) + this.Unlock() + b = b[size:] + nBytes += size + } + return nBytes +} + +func (this *SendingWorker) Write(seg ISegment) { + dataSeg := seg.(*DataSegment) + + dataSeg.Conv = this.kcp.conv + dataSeg.Timestamp = this.kcp.current + dataSeg.SendingNext = this.firstUnacknowledged + dataSeg.Opt = 0 + if this.kcp.state == StateReadyToClose { + dataSeg.Opt = SegmentOptionClose + } + + this.kcp.output.Write(dataSeg) + this.updated = false +} + +func (this *SendingWorker) PingNecessary() bool { + return this.updated +} + +func (this *SendingWorker) OnPacketLoss(lost bool) { + if !effectiveConfig.Congestion { + return + } + + if lost { + this.controlWindow = 3 * this.controlWindow / 4 + } else { + this.controlWindow += this.controlWindow / 4 + } + if this.controlWindow < 4 { + this.controlWindow = 4 + } + if this.controlWindow > this.windowSize { + this.controlWindow = this.windowSize + } +} + +func (this *SendingWorker) Flush() { + this.Lock() + defer this.Unlock() + + cwnd := this.firstUnacknowledged + this.windowSize + if cwnd > this.remoteNextNumber { + cwnd = this.remoteNextNumber + } + if effectiveConfig.Congestion && cwnd > this.firstUnacknowledged+this.controlWindow { + cwnd = this.firstUnacknowledged + this.controlWindow + } + + for !this.queue.IsEmpty() && _itimediff(this.nextNumber, cwnd) < 0 { + seg := this.queue.Pop() + seg.Number = this.nextNumber + seg.timeout = this.kcp.current + seg.ackSkipped = 0 + seg.transmit = 0 + this.window.Push(seg) + this.nextNumber++ + } + + this.window.Flush(this.kcp.current, this.kcp.fastresend, this.kcp.rx_rto) +} + +func (this *SendingWorker) CloseWrite() { + this.Lock() + defer this.Unlock() + + this.window.Clear(0xFFFFFFFF) + this.queue.Clear() +}