diff --git a/transport/internet/kcp/config.go b/transport/internet/kcp/config.go index 31844041..a7607cfb 100644 --- a/transport/internet/kcp/config.go +++ b/transport/internet/kcp/config.go @@ -1,10 +1,10 @@ package kcp type Config struct { - Mtu int // Maximum transmission unit - Tti int - UplinkCapacity int - DownlinkCapacity int + Mtu uint32 // Maximum transmission unit + Tti uint32 + UplinkCapacity uint32 + DownlinkCapacity uint32 Congestion bool WriteBuffer int } @@ -13,11 +13,11 @@ func (this *Config) Apply() { effectiveConfig = *this } -func (this *Config) GetSendingWindowSize() int { +func (this *Config) GetSendingWindowSize() uint32 { return this.UplinkCapacity * 1024 * 1024 / this.Mtu / (1000 / this.Tti) } -func (this *Config) GetReceivingWindowSize() int { +func (this *Config) GetReceivingWindowSize() uint32 { return this.DownlinkCapacity * 1024 * 1024 / this.Mtu / (1000 / this.Tti) } diff --git a/transport/internet/kcp/config_json.go b/transport/internet/kcp/config_json.go index c400078a..b476112a 100644 --- a/transport/internet/kcp/config_json.go +++ b/transport/internet/kcp/config_json.go @@ -11,11 +11,11 @@ import ( func (this *Config) UnmarshalJSON(data []byte) error { type JSONConfig struct { - Mtu *int `json:"mtu"` - Tti *int `json:"tti"` - UpCap *int `json:"uplinkCapacity"` - DownCap *int `json:"downlinkCapacity"` - Congestion *bool `json:"congestion"` + Mtu *uint32 `json:"mtu"` + Tti *uint32 `json:"tti"` + UpCap *uint32 `json:"uplinkCapacity"` + DownCap *uint32 `json:"downlinkCapacity"` + Congestion *bool `json:"congestion"` } jsonConfig := new(JSONConfig) if err := json.Unmarshal(data, &jsonConfig); err != nil { @@ -39,7 +39,7 @@ func (this *Config) UnmarshalJSON(data []byte) error { } if jsonConfig.UpCap != nil { upCap := *jsonConfig.UpCap - if upCap < 0 { + if upCap == 0 { log.Error("KCP|Config: Invalid uplink capacity: ", upCap) return common.ErrBadConfiguration } @@ -47,7 +47,7 @@ func (this *Config) UnmarshalJSON(data []byte) error { } if jsonConfig.DownCap != nil { downCap := *jsonConfig.DownCap - if downCap < 0 { + if downCap == 0 { log.Error("KCP|Config: Invalid downlink capacity: ", downCap) return common.ErrBadConfiguration } diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 30cee644..64c5f1f1 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -20,7 +20,7 @@ var ( ) const ( - headerSize = 2 + headerSize uint32 = 2 ) type Command byte @@ -65,7 +65,7 @@ type Connection struct { writer io.WriteCloser since int64 terminateOnce signal.Once - writeBufferSize int + writeBufferSize uint32 } // NewConnection create a new KCP connection between local and remote. @@ -78,12 +78,11 @@ func NewConnection(conv uint32, writerCloser io.WriteCloser, local *net.UDPAddr, conn.writer = writerCloser conn.since = nowMillisec() - mtu := uint32(effectiveConfig.Mtu - block.HeaderSize() - headerSize) - conn.kcp = NewKCP(conv, mtu, conn.output) - conn.kcp.WndSize(effectiveConfig.GetSendingWindowSize(), effectiveConfig.GetReceivingWindowSize()) + mtu := effectiveConfig.Mtu - uint32(block.HeaderSize()) - headerSize + conn.kcp = NewKCP(conv, mtu, effectiveConfig.GetSendingWindowSize(), effectiveConfig.GetReceivingWindowSize(), conn.output) conn.kcp.NoDelay(effectiveConfig.Tti, 2, effectiveConfig.Congestion) conn.kcp.current = conn.Elapsed() - conn.writeBufferSize = effectiveConfig.WriteBuffer / effectiveConfig.Mtu + conn.writeBufferSize = uint32(effectiveConfig.WriteBuffer) / effectiveConfig.Mtu go conn.updateTask() diff --git a/transport/internet/kcp/kcp.go b/transport/internet/kcp/kcp.go index 662245f4..124be025 100644 --- a/transport/internet/kcp/kcp.go +++ b/transport/internet/kcp/kcp.go @@ -149,7 +149,7 @@ type KCP struct { snd_queue []*Segment rcv_queue []*Segment snd_buf []*Segment - rcv_buf []*Segment + rcv_buf *ReceivingWindow acklist []uint32 @@ -161,11 +161,11 @@ type KCP struct { // NewKCP create a new kcp control object, 'conv' must equal in two endpoint // from the same connection. -func NewKCP(conv uint32, mtu uint32, output Output) *KCP { +func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSize uint32, output Output) *KCP { kcp := new(KCP) kcp.conv = conv - kcp.snd_wnd = IKCP_WND_SND - kcp.rcv_wnd = IKCP_WND_RCV + kcp.snd_wnd = sendingWindowSize + kcp.rcv_wnd = receivingWindowSize kcp.rmt_wnd = IKCP_WND_RCV kcp.mtu = mtu kcp.mss = kcp.mtu - IKCP_OVERHEAD @@ -176,6 +176,7 @@ func NewKCP(conv uint32, mtu uint32, output Output) *KCP { kcp.ssthresh = IKCP_THRESH_INIT kcp.dead_link = IKCP_DEADLINK kcp.output = output + kcp.rcv_buf = NewReceivingWindow(receivingWindowSize) return kcp } @@ -205,19 +206,7 @@ func (kcp *KCP) Recv(buffer []byte) (n int) { } kcp.rcv_queue = kcp.rcv_queue[count:] - // move available data from rcv_buf -> rcv_queue - count = 0 - for _, seg := range kcp.rcv_buf { - if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue) < int(kcp.rcv_wnd) { - kcp.rcv_queue = append(kcp.rcv_queue, seg) - kcp.rcv_nxt++ - count++ - } else { - break - } - } - kcp.rcv_buf = kcp.rcv_buf[count:] - + kcp.DumpReceivingBuf() // fast recover if len(kcp.rcv_queue) < int(kcp.rcv_wnd) && fast_recover { // ready to send back IKCP_CMD_WINS in ikcp_flush @@ -227,6 +216,20 @@ func (kcp *KCP) Recv(buffer []byte) (n int) { return } +// DumpReceivingBuf moves available data from rcv_buf -> rcv_queue +// @Private +func (kcp *KCP) DumpReceivingBuf() { + for { + seg := kcp.rcv_buf.RemoveFirst() + if seg == nil { + break + } + kcp.rcv_queue = append(kcp.rcv_queue, seg) + kcp.rcv_buf.Advance() + kcp.rcv_nxt++ + } +} + // Send is user/upper level send, returns below zero for error func (kcp *KCP) Send(buffer []byte) int { var count int @@ -357,43 +360,12 @@ func (kcp *KCP) parse_data(newseg *Segment) { return } - n := len(kcp.rcv_buf) - 1 - insert_idx := 0 - repeat := false - for i := n; i >= 0; i-- { - seg := kcp.rcv_buf[i] - if seg.sn == sn { - repeat = true - break - } - if _itimediff(sn, seg.sn) > 0 { - insert_idx = i + 1 - break - } + idx := sn - kcp.rcv_nxt + if !kcp.rcv_buf.Set(idx, newseg) { + newseg.Release() } - if !repeat { - if insert_idx == n+1 { - kcp.rcv_buf = append(kcp.rcv_buf, newseg) - } else { - kcp.rcv_buf = append(kcp.rcv_buf, &Segment{}) - copy(kcp.rcv_buf[insert_idx+1:], kcp.rcv_buf[insert_idx:]) - kcp.rcv_buf[insert_idx] = newseg - } - } - - // move available data from rcv_buf -> rcv_queue - count := 0 - for k, seg := range kcp.rcv_buf { - if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue) < int(kcp.rcv_wnd) { - kcp.rcv_queue = append(kcp.rcv_queue, kcp.rcv_buf[k]) - kcp.rcv_nxt++ - count++ - } else { - break - } - } - kcp.rcv_buf = kcp.rcv_buf[count:] + kcp.DumpReceivingBuf() } // Input when you received a low level packet (eg. UDP packet), call it @@ -790,15 +762,9 @@ func (kcp *KCP) Check(current uint32) uint32 { // interval: internal update timer interval in millisec, default is 100ms // resend: 0:disable fast resend(default), 1:enable fast resend // nc: 0:normal congestion control(default), 1:disable congestion control -func (kcp *KCP) NoDelay(interval, resend int, congestionControl bool) int { - if interval >= 0 { - if interval > 5000 { - interval = 5000 - } else if interval < 10 { - interval = 10 - } - kcp.interval = uint32(interval) - } +func (kcp *KCP) NoDelay(interval uint32, resend int, congestionControl bool) int { + kcp.interval = interval + if resend >= 0 { kcp.fastresend = int32(resend) } @@ -806,20 +772,9 @@ func (kcp *KCP) NoDelay(interval, resend int, congestionControl bool) int { return 0 } -// WndSize sets maximum window size: sndwnd=32, rcvwnd=32 by default -func (kcp *KCP) WndSize(sndwnd, rcvwnd int) int { - if sndwnd > 0 { - kcp.snd_wnd = uint32(sndwnd) - } - if rcvwnd > 0 { - kcp.rcv_wnd = uint32(rcvwnd) - } - return 0 -} - // WaitSnd gets how many packet is waiting to be sent -func (kcp *KCP) WaitSnd() int { - return len(kcp.snd_buf) + len(kcp.snd_queue) +func (kcp *KCP) WaitSnd() uint32 { + return uint32(len(kcp.snd_buf) + len(kcp.snd_queue)) } func (this *KCP) ClearSendQueue() { diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go new file mode 100644 index 00000000..f3bbc8c7 --- /dev/null +++ b/transport/internet/kcp/receiving.go @@ -0,0 +1,53 @@ +package kcp + +type ReceivingWindow struct { + start uint32 + size uint32 + list []*Segment +} + +func NewReceivingWindow(size uint32) *ReceivingWindow { + return &ReceivingWindow{ + start: 0, + size: size, + list: make([]*Segment, size), + } +} + +func (this *ReceivingWindow) Size() uint32 { + return this.size +} + +func (this *ReceivingWindow) Position(idx uint32) uint32 { + return (idx + this.start) % this.size +} + +func (this *ReceivingWindow) Set(idx uint32, value *Segment) bool { + pos := this.Position(idx) + if this.list[pos] != nil { + return false + } + this.list[pos] = value + return true +} + +func (this *ReceivingWindow) Remove(idx uint32) *Segment { + pos := this.Position(idx) + if this.list[pos] == nil { + return nil + } + e := this.list[pos] + this.list[pos] = nil + return e +} + +func (this *ReceivingWindow) RemoveFirst() *Segment { + return this.Remove(0) +} + +func (this *ReceivingWindow) Advance() { + this.start++ + if this.start == this.size { + this.start = 0 + } +} diff --git a/transport/internet/kcp/receiving_test.go b/transport/internet/kcp/receiving_test.go new file mode 100644 index 00000000..ef2a6e72 --- /dev/null +++ b/transport/internet/kcp/receiving_test.go @@ -0,0 +1,32 @@ +package kcp_test + +import ( + "testing" + + "github.com/v2ray/v2ray-core/testing/assert" + . "github.com/v2ray/v2ray-core/transport/internet/kcp" +) + +func TestRecivingWindow(t *testing.T) { + assert := assert.On(t) + + window := NewReceivingWindow(3) + + seg0 := &Segment{} + seg1 := &Segment{} + seg2 := &Segment{} + seg3 := &Segment{} + + assert.Bool(window.Set(0, seg0)).IsTrue() + assert.Pointer(window.RemoveFirst()).Equals(seg0) + + assert.Bool(window.Set(1, seg1)).IsTrue() + assert.Bool(window.Set(2, seg2)).IsTrue() + + window.Advance() + assert.Bool(window.Set(2, seg3)).IsTrue() + + assert.Pointer(window.RemoveFirst()).Equals(seg1) + assert.Pointer(window.Remove(1)).Equals(seg2) + assert.Pointer(window.Remove(2)).Equals(seg3) +}