fix race conditions in kcp

pull/432/head
Darien Raymond 2017-02-18 00:04:25 +01:00
parent 587ada599c
commit ebed271a92
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
4 changed files with 77 additions and 28 deletions

View File

@ -116,7 +116,7 @@ func (v *RoundTripInfo) SmoothedTime() uint32 {
} }
type Updater struct { type Updater struct {
interval time.Duration interval int64
shouldContinue predicate.Predicate shouldContinue predicate.Predicate
shouldTerminate predicate.Predicate shouldTerminate predicate.Predicate
updateFunc func() updateFunc func()
@ -125,7 +125,7 @@ type Updater struct {
func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater { func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater {
u := &Updater{ u := &Updater{
interval: time.Duration(interval) * time.Millisecond, interval: int64(time.Duration(interval) * time.Millisecond),
shouldContinue: shouldContinue, shouldContinue: shouldContinue,
shouldTerminate: shouldTerminate, shouldTerminate: shouldTerminate,
updateFunc: updateFunc, updateFunc: updateFunc,
@ -149,11 +149,19 @@ func (v *Updater) Run() {
} }
for v.shouldContinue() { for v.shouldContinue() {
v.updateFunc() v.updateFunc()
time.Sleep(v.interval) time.Sleep(v.Interval())
} }
} }
} }
func (u *Updater) Interval() time.Duration {
return time.Duration(atomic.LoadInt64(&u.interval))
}
func (u *Updater) SetInterval(d time.Duration) {
atomic.StoreInt64(&u.interval, int64(d))
}
type SystemConnection interface { type SystemConnection interface {
net.Conn net.Conn
Id() internal.ConnectionID Id() internal.ConnectionID
@ -342,14 +350,14 @@ func (v *Connection) SetState(state State) {
case StateTerminating: case StateTerminating:
v.receivingWorker.CloseRead() v.receivingWorker.CloseRead()
v.sendingWorker.CloseWrite() v.sendingWorker.CloseWrite()
v.pingUpdater.interval = time.Second v.pingUpdater.SetInterval(time.Second)
case StatePeerTerminating: case StatePeerTerminating:
v.sendingWorker.CloseWrite() v.sendingWorker.CloseWrite()
v.pingUpdater.interval = time.Second v.pingUpdater.SetInterval(time.Second)
case StateTerminated: case StateTerminated:
v.receivingWorker.CloseRead() v.receivingWorker.CloseRead()
v.sendingWorker.CloseWrite() v.sendingWorker.CloseWrite()
v.pingUpdater.interval = time.Second v.pingUpdater.SetInterval(time.Second)
v.dataUpdater.WakeUp() v.dataUpdater.WakeUp()
v.pingUpdater.WakeUp() v.pingUpdater.WakeUp()
go v.Terminate() go v.Terminate()
@ -491,7 +499,7 @@ func (v *Connection) Input(segments []Segment) {
case *DataSegment: case *DataSegment:
v.HandleOption(seg.Option) v.HandleOption(seg.Option)
v.receivingWorker.ProcessSegment(seg) v.receivingWorker.ProcessSegment(seg)
if seg.Number == v.receivingWorker.nextNumber { if v.receivingWorker.IsDataAvailable() {
v.OnDataInput() v.OnDataInput()
} }
v.dataUpdater.WakeUp() v.dataUpdater.WakeUp()
@ -573,8 +581,8 @@ func (v *Connection) Ping(current uint32, cmd Command) {
seg := NewCmdOnlySegment() seg := NewCmdOnlySegment()
seg.Conv = v.conv seg.Conv = v.conv
seg.Cmd = cmd seg.Cmd = cmd
seg.ReceivinNext = v.receivingWorker.nextNumber seg.ReceivinNext = v.receivingWorker.NextNumber()
seg.SendingNext = v.sendingWorker.firstUnacknowledged seg.SendingNext = v.sendingWorker.FirstUnacknowledged()
seg.PeerRTO = v.roundTrip.Timeout() seg.PeerRTO = v.roundTrip.Timeout()
if v.State() == StateReadyToClose { if v.State() == StateReadyToClose {
seg.Option = SegmentOptionClose seg.Option = SegmentOptionClose

View File

@ -79,7 +79,7 @@ func (o *ServerConnection) Id() internal.ConnectionID {
// Listener defines a server listening for connections // Listener defines a server listening for connections
type Listener struct { type Listener struct {
sync.Mutex sync.Mutex
running bool closed chan bool
sessions map[ConnectionID]*Connection sessions map[ConnectionID]*Connection
awaitingConns chan *Connection awaitingConns chan *Connection
hub *udp.Hub hub *udp.Hub
@ -116,7 +116,7 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
}, },
sessions: make(map[ConnectionID]*Connection), sessions: make(map[ConnectionID]*Connection),
awaitingConns: make(chan *Connection, 64), awaitingConns: make(chan *Connection, 64),
running: true, closed: make(chan bool),
config: kcpSettings, config: kcpSettings,
} }
if options.Stream != nil && options.Stream.HasSecuritySettings() { if options.Stream != nil && options.Stream.HasSecuritySettings() {
@ -134,7 +134,9 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
if err != nil { if err != nil {
return nil, err return nil, err
} }
l.Lock()
l.hub = hub l.hub = hub
l.Unlock()
log.Info("KCP|Listener: listening on ", address, ":", port) log.Info("KCP|Listener: listening on ", address, ":", port)
return l, nil return l, nil
} }
@ -148,12 +150,15 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
return return
} }
if !v.running { select {
case <-v.closed:
return return
default:
} }
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
if !v.running { if v.hub == nil {
return return
} }
if payload.Len() < 4 { if payload.Len() < 4 {
@ -208,24 +213,22 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
} }
func (v *Listener) Remove(id ConnectionID) { func (v *Listener) Remove(id ConnectionID) {
if !v.running { select {
case <-v.closed:
return return
default:
v.Lock()
delete(v.sessions, id)
v.Unlock()
} }
v.Lock()
defer v.Unlock()
if !v.running {
return
}
delete(v.sessions, id)
} }
// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. // Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
func (v *Listener) Accept() (internet.Connection, error) { func (v *Listener) Accept() (internet.Connection, error) {
for { for {
if !v.running {
return nil, ErrClosedListener
}
select { select {
case <-v.closed:
return nil, ErrClosedListener
case conn, open := <-v.awaitingConns: case conn, open := <-v.awaitingConns:
if !open { if !open {
break break
@ -243,13 +246,15 @@ func (v *Listener) Accept() (internet.Connection, error) {
// Close stops listening on the UDP address. Already Accepted connections are not closed. // Close stops listening on the UDP address. Already Accepted connections are not closed.
func (v *Listener) Close() error { func (v *Listener) Close() error {
if !v.running { select {
case <-v.closed:
return ErrClosedListener return ErrClosedListener
default:
} }
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
v.running = false close(v.closed)
close(v.awaitingConns) close(v.awaitingConns)
for _, conn := range v.sessions { for _, conn := range v.sessions {
go conn.Terminate() go conn.Terminate()

View File

@ -48,6 +48,10 @@ func (v *ReceivingWindow) RemoveFirst() *DataSegment {
return v.Remove(0) return v.Remove(0)
} }
func (w *ReceivingWindow) HasFirst() bool {
return w.list[w.Position(0)] != nil
}
func (v *ReceivingWindow) Advance() { func (v *ReceivingWindow) Advance() {
v.start++ v.start++
if v.start == v.size { if v.start == v.size {
@ -163,7 +167,9 @@ func NewReceivingWorker(kcp *Connection) *ReceivingWorker {
} }
func (v *ReceivingWorker) Release() { func (v *ReceivingWorker) Release() {
v.Lock()
v.leftOver.Release() v.leftOver.Release()
v.Unlock()
} }
func (v *ReceivingWorker) ProcessSendingNext(number uint32) { func (v *ReceivingWorker) ProcessSendingNext(number uint32) {
@ -228,6 +234,19 @@ func (v *ReceivingWorker) Read(b []byte) int {
return total return total
} }
func (w *ReceivingWorker) IsDataAvailable() bool {
w.RLock()
defer w.RUnlock()
return w.window.HasFirst()
}
func (w *ReceivingWorker) NextNumber() uint32 {
w.RLock()
defer w.RUnlock()
return w.nextNumber
}
func (v *ReceivingWorker) Flush(current uint32) { func (v *ReceivingWorker) Flush(current uint32) {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
@ -250,5 +269,8 @@ func (v *ReceivingWorker) CloseRead() {
} }
func (v *ReceivingWorker) UpdateNecessary() bool { func (v *ReceivingWorker) UpdateNecessary() bool {
v.RLock()
defer v.RUnlock()
return len(v.acklist.numbers) > 0 return len(v.acklist.numbers) > 0
} }

View File

@ -207,7 +207,9 @@ func NewSendingWorker(kcp *Connection) *SendingWorker {
} }
func (v *SendingWorker) Release() { func (v *SendingWorker) Release() {
v.Lock()
v.window.Release() v.window.Release()
v.Unlock()
} }
func (v *SendingWorker) ProcessReceivingNext(nextNumber uint32) { func (v *SendingWorker) ProcessReceivingNext(nextNumber uint32) {
@ -336,7 +338,6 @@ func (v *SendingWorker) OnPacketLoss(lossRate uint32) {
func (v *SendingWorker) Flush(current uint32) { func (v *SendingWorker) Flush(current uint32) {
v.Lock() v.Lock()
defer v.Unlock()
cwnd := v.firstUnacknowledged + v.conn.Config.GetSendingInFlightSize() cwnd := v.firstUnacknowledged + v.conn.Config.GetSendingInFlightSize()
if cwnd > v.remoteNextNumber { if cwnd > v.remoteNextNumber {
@ -348,11 +349,17 @@ func (v *SendingWorker) Flush(current uint32) {
if !v.window.IsEmpty() { if !v.window.IsEmpty() {
v.window.Flush(current, v.conn.roundTrip.Timeout(), cwnd) v.window.Flush(current, v.conn.roundTrip.Timeout(), cwnd)
} else if v.firstUnacknowledgedUpdated { v.firstUnacknowledgedUpdated = false
v.conn.Ping(current, CommandPing)
} }
updated := v.firstUnacknowledgedUpdated
v.firstUnacknowledgedUpdated = false v.firstUnacknowledgedUpdated = false
v.Unlock()
if updated {
v.conn.Ping(current, CommandPing)
}
} }
func (v *SendingWorker) CloseWrite() { func (v *SendingWorker) CloseWrite() {
@ -372,3 +379,10 @@ func (v *SendingWorker) IsEmpty() bool {
func (v *SendingWorker) UpdateNecessary() bool { func (v *SendingWorker) UpdateNecessary() bool {
return !v.IsEmpty() return !v.IsEmpty()
} }
func (w *SendingWorker) FirstUnacknowledged() uint32 {
w.RLock()
defer w.RUnlock()
return w.firstUnacknowledged
}