diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 5c90447b..b46ebc92 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -116,7 +116,7 @@ func (v *RoundTripInfo) SmoothedTime() uint32 { } type Updater struct { - interval time.Duration + interval int64 shouldContinue predicate.Predicate shouldTerminate predicate.Predicate updateFunc func() @@ -125,7 +125,7 @@ type Updater struct { func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater { u := &Updater{ - interval: time.Duration(interval) * time.Millisecond, + interval: int64(time.Duration(interval) * time.Millisecond), shouldContinue: shouldContinue, shouldTerminate: shouldTerminate, updateFunc: updateFunc, @@ -149,11 +149,19 @@ func (v *Updater) Run() { } for v.shouldContinue() { v.updateFunc() - time.Sleep(v.interval) + time.Sleep(v.Interval()) } } } +func (u *Updater) Interval() time.Duration { + return time.Duration(atomic.LoadInt64(&u.interval)) +} + +func (u *Updater) SetInterval(d time.Duration) { + atomic.StoreInt64(&u.interval, int64(d)) +} + type SystemConnection interface { net.Conn Id() internal.ConnectionID @@ -342,14 +350,14 @@ func (v *Connection) SetState(state State) { case StateTerminating: v.receivingWorker.CloseRead() v.sendingWorker.CloseWrite() - v.pingUpdater.interval = time.Second + v.pingUpdater.SetInterval(time.Second) case StatePeerTerminating: v.sendingWorker.CloseWrite() - v.pingUpdater.interval = time.Second + v.pingUpdater.SetInterval(time.Second) case StateTerminated: v.receivingWorker.CloseRead() v.sendingWorker.CloseWrite() - v.pingUpdater.interval = time.Second + v.pingUpdater.SetInterval(time.Second) v.dataUpdater.WakeUp() v.pingUpdater.WakeUp() go v.Terminate() @@ -491,7 +499,7 @@ func (v *Connection) Input(segments []Segment) { case *DataSegment: v.HandleOption(seg.Option) v.receivingWorker.ProcessSegment(seg) - if seg.Number == v.receivingWorker.nextNumber { + if v.receivingWorker.IsDataAvailable() { v.OnDataInput() } v.dataUpdater.WakeUp() @@ -573,8 +581,8 @@ func (v *Connection) Ping(current uint32, cmd Command) { seg := NewCmdOnlySegment() seg.Conv = v.conv seg.Cmd = cmd - seg.ReceivinNext = v.receivingWorker.nextNumber - seg.SendingNext = v.sendingWorker.firstUnacknowledged + seg.ReceivinNext = v.receivingWorker.NextNumber() + seg.SendingNext = v.sendingWorker.FirstUnacknowledged() seg.PeerRTO = v.roundTrip.Timeout() if v.State() == StateReadyToClose { seg.Option = SegmentOptionClose diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 477c7b65..4b3ebed9 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -79,7 +79,7 @@ func (o *ServerConnection) Id() internal.ConnectionID { // Listener defines a server listening for connections type Listener struct { sync.Mutex - running bool + closed chan bool sessions map[ConnectionID]*Connection awaitingConns chan *Connection hub *udp.Hub @@ -116,7 +116,7 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen }, sessions: make(map[ConnectionID]*Connection), awaitingConns: make(chan *Connection, 64), - running: true, + closed: make(chan bool), config: kcpSettings, } if options.Stream != nil && options.Stream.HasSecuritySettings() { @@ -134,7 +134,9 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen if err != nil { return nil, err } + l.Lock() l.hub = hub + l.Unlock() log.Info("KCP|Listener: listening on ", address, ":", port) return l, nil } @@ -148,12 +150,15 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina return } - if !v.running { + select { + case <-v.closed: return + default: } + v.Lock() defer v.Unlock() - if !v.running { + if v.hub == nil { return } if payload.Len() < 4 { @@ -208,24 +213,22 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina } func (v *Listener) Remove(id ConnectionID) { - if !v.running { + select { + case <-v.closed: return + default: + v.Lock() + delete(v.sessions, id) + v.Unlock() } - v.Lock() - defer v.Unlock() - if !v.running { - return - } - delete(v.sessions, id) } // Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. func (v *Listener) Accept() (internet.Connection, error) { for { - if !v.running { - return nil, ErrClosedListener - } select { + case <-v.closed: + return nil, ErrClosedListener case conn, open := <-v.awaitingConns: if !open { break @@ -243,13 +246,15 @@ func (v *Listener) Accept() (internet.Connection, error) { // Close stops listening on the UDP address. Already Accepted connections are not closed. func (v *Listener) Close() error { - if !v.running { + select { + case <-v.closed: return ErrClosedListener + default: } v.Lock() defer v.Unlock() - v.running = false + close(v.closed) close(v.awaitingConns) for _, conn := range v.sessions { go conn.Terminate() diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index b85ccd9f..b5a0cc80 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -48,6 +48,10 @@ func (v *ReceivingWindow) RemoveFirst() *DataSegment { return v.Remove(0) } +func (w *ReceivingWindow) HasFirst() bool { + return w.list[w.Position(0)] != nil +} + func (v *ReceivingWindow) Advance() { v.start++ if v.start == v.size { @@ -163,7 +167,9 @@ func NewReceivingWorker(kcp *Connection) *ReceivingWorker { } func (v *ReceivingWorker) Release() { + v.Lock() v.leftOver.Release() + v.Unlock() } func (v *ReceivingWorker) ProcessSendingNext(number uint32) { @@ -228,6 +234,19 @@ func (v *ReceivingWorker) Read(b []byte) int { return total } +func (w *ReceivingWorker) IsDataAvailable() bool { + w.RLock() + defer w.RUnlock() + return w.window.HasFirst() +} + +func (w *ReceivingWorker) NextNumber() uint32 { + w.RLock() + defer w.RUnlock() + + return w.nextNumber +} + func (v *ReceivingWorker) Flush(current uint32) { v.Lock() defer v.Unlock() @@ -250,5 +269,8 @@ func (v *ReceivingWorker) CloseRead() { } func (v *ReceivingWorker) UpdateNecessary() bool { + v.RLock() + defer v.RUnlock() + return len(v.acklist.numbers) > 0 } diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index 460d51a5..f54b8559 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -207,7 +207,9 @@ func NewSendingWorker(kcp *Connection) *SendingWorker { } func (v *SendingWorker) Release() { + v.Lock() v.window.Release() + v.Unlock() } func (v *SendingWorker) ProcessReceivingNext(nextNumber uint32) { @@ -336,7 +338,6 @@ func (v *SendingWorker) OnPacketLoss(lossRate uint32) { func (v *SendingWorker) Flush(current uint32) { v.Lock() - defer v.Unlock() cwnd := v.firstUnacknowledged + v.conn.Config.GetSendingInFlightSize() if cwnd > v.remoteNextNumber { @@ -348,11 +349,17 @@ func (v *SendingWorker) Flush(current uint32) { if !v.window.IsEmpty() { v.window.Flush(current, v.conn.roundTrip.Timeout(), cwnd) - } else if v.firstUnacknowledgedUpdated { - v.conn.Ping(current, CommandPing) + v.firstUnacknowledgedUpdated = false } + updated := v.firstUnacknowledgedUpdated v.firstUnacknowledgedUpdated = false + + v.Unlock() + + if updated { + v.conn.Ping(current, CommandPing) + } } func (v *SendingWorker) CloseWrite() { @@ -372,3 +379,10 @@ func (v *SendingWorker) IsEmpty() bool { func (v *SendingWorker) UpdateNecessary() bool { return !v.IsEmpty() } + +func (w *SendingWorker) FirstUnacknowledged() uint32 { + w.RLock() + defer w.RUnlock() + + return w.firstUnacknowledged +}