diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 860af5eb..a6ccd8ca 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -176,6 +176,7 @@ func (this *Connection) Read(b []byte) (int, error) { if nBytes > 0 { return nBytes, nil } + var timer *time.Timer if !this.rd.IsZero() { duration := this.rd.Sub(time.Now()) @@ -229,7 +230,7 @@ func (this *Connection) SetState(state State) { current := this.Elapsed() atomic.StoreInt32((*int32)(&this.state), int32(state)) atomic.StoreUint32(&this.stateBeginTime, current) - log.Info("KCP|Connection: Entering state ", state, " at ", current) + log.Info("KCP|Connection: #", this.conv, " entering state ", state, " at ", current) switch state { case StateReadyToClose: @@ -429,12 +430,18 @@ func (this *Connection) flush() { if this.State() == StateActive && current-atomic.LoadUint32(&this.lastIncomingTime) >= 30000 { this.Close() } + if this.State() == StateReadyToClose && this.sendingWorker.IsEmpty() { + this.SetState(StateTerminating) + } if this.State() == StateTerminating { - this.output.Write(&CmdOnlySegment{ - Conv: this.conv, - Cmd: SegmentCommandTerminated, - }) + log.Debug("KCP|Connection: #", this.conv, " sending terminating cmd.") + seg := NewCmdOnlySegment() + defer seg.Release() + + seg.Conv = this.conv + seg.Cmd = SegmentCommandTerminated + this.output.Write(seg) this.output.Flush() if current-atomic.LoadUint32(&this.stateBeginTime) > 8000 { diff --git a/transport/internet/kcp/kcp_test.go b/transport/internet/kcp/kcp_test.go new file mode 100644 index 00000000..e6dde669 --- /dev/null +++ b/transport/internet/kcp/kcp_test.go @@ -0,0 +1,77 @@ +package kcp_test + +import ( + "crypto/rand" + "io" + "sync" + "testing" + "time" + + v2net "github.com/v2ray/v2ray-core/common/net" + v2nettesting "github.com/v2ray/v2ray-core/common/net/testing" + "github.com/v2ray/v2ray-core/testing/assert" + . "github.com/v2ray/v2ray-core/transport/internet/kcp" +) + +func TestDialAndListen(t *testing.T) { + assert := assert.On(t) + + port := v2nettesting.PickPort() + listerner, err := NewListener(v2net.LocalHostIP, port) + assert.Error(err).IsNil() + + go func() { + for { + conn, err := listerner.Accept() + if err != nil { + break + } + go func() { + payload := make([]byte, 1024) + for { + nBytes, err := conn.Read(payload) + if err != nil { + break + } + for idx, b := range payload[:nBytes] { + payload[idx] = b ^ 'c' + } + conn.Write(payload[:nBytes]) + } + conn.Close() + }() + } + }() + + wg := new(sync.WaitGroup) + for i := 0; i < 10; i++ { + clientConn, err := DialKCP(v2net.LocalHostIP, v2net.UDPDestination(v2net.LocalHostIP, port)) + assert.Error(err).IsNil() + wg.Add(1) + + go func() { + clientSend := make([]byte, 1024*1024) + rand.Read(clientSend) + clientConn.Write(clientSend) + + clientReceived := make([]byte, 1024*1024) + nBytes, _ := io.ReadFull(clientConn, clientReceived) + assert.Int(nBytes).Equals(len(clientReceived)) + clientConn.Close() + + clientExpected := make([]byte, 1024*1024) + for idx, b := range clientSend { + clientExpected[idx] = b ^ 'c' + } + assert.Bytes(clientReceived).Equals(clientExpected) + + wg.Done() + }() + } + + wg.Wait() + time.Sleep(15 * time.Second) + assert.Int(listerner.ActiveConnections()).Equals(0) + + listerner.Close() +} diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index dc37bbec..2d948e03 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -63,6 +63,7 @@ func (this *Listener) OnReceive(payload *alloc.Buffer, src v2net.Destination) { sourceId := src.NetAddr() + "|" + serial.Uint16ToString(conv) conn, found := this.sessions[sourceId] if !found { + log.Debug("KCP|Listener: Creating session with id(", sourceId, ") from ", src) writer := &Writer{ id: sourceId, hub: this.hub, @@ -94,6 +95,7 @@ func (this *Listener) Remove(dest string) { if !this.running { return } + log.Debug("KCP|Listener: Removing session ", dest) delete(this.sessions, dest) } @@ -130,6 +132,13 @@ func (this *Listener) Close() error { return nil } +func (this *Listener) ActiveConnections() int { + this.Lock() + defer this.Unlock() + + return len(this.sessions) +} + // Addr returns the listener's network address, The Addr returned is shared by all invocations of Addr, so do not modify it. func (this *Listener) Addr() net.Addr { return this.localAddr diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index 03fe998b..2fb3899f 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -40,6 +40,10 @@ func (this *SendingWindow) Len() int { return int(this.len) } +func (this *SendingWindow) IsEmpty() bool { + return this.len == 0 +} + func (this *SendingWindow) Size() uint32 { return this.cap } @@ -64,7 +68,7 @@ func (this *SendingWindow) First() *DataSegment { } func (this *SendingWindow) Clear(una uint32) { - for this.Len() > 0 && this.data[this.start].Number < una { + for !this.IsEmpty() && this.data[this.start].Number < una { this.Remove(0) } } @@ -121,7 +125,7 @@ func (this *SendingWindow) HandleFastAck(number uint32) { } func (this *SendingWindow) Flush(current uint32, resend uint32, rto uint32, maxInFlightSize uint32) { - if this.Len() == 0 { + if this.IsEmpty() { return } @@ -266,7 +270,7 @@ func (this *SendingWorker) ProcessReceivingNextWithoutLock(nextNumber uint32) { // @Private func (this *SendingWorker) FindFirstUnacknowledged() { prevUna := this.firstUnacknowledged - if this.window.Len() > 0 { + if !this.window.IsEmpty() { this.firstUnacknowledged = this.window.First().Number } else { this.firstUnacknowledged = this.nextNumber @@ -410,3 +414,10 @@ func (this *SendingWorker) CloseWrite() { this.window.Clear(0xFFFFFFFF) this.queue.Clear() } + +func (this *SendingWorker) IsEmpty() bool { + this.RLock() + defer this.RUnlock() + + return this.window.IsEmpty() && this.queue.IsEmpty() +}