fix race conditions in kcp

pull/432/head
Darien Raymond 2017-02-18 00:04:25 +01:00
parent 587ada599c
commit ebed271a92
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
4 changed files with 77 additions and 28 deletions

View File

@ -116,7 +116,7 @@ func (v *RoundTripInfo) SmoothedTime() uint32 {
}
type Updater struct {
interval time.Duration
interval int64
shouldContinue predicate.Predicate
shouldTerminate predicate.Predicate
updateFunc func()
@ -125,7 +125,7 @@ type Updater struct {
func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater {
u := &Updater{
interval: time.Duration(interval) * time.Millisecond,
interval: int64(time.Duration(interval) * time.Millisecond),
shouldContinue: shouldContinue,
shouldTerminate: shouldTerminate,
updateFunc: updateFunc,
@ -149,11 +149,19 @@ func (v *Updater) Run() {
}
for v.shouldContinue() {
v.updateFunc()
time.Sleep(v.interval)
time.Sleep(v.Interval())
}
}
}
func (u *Updater) Interval() time.Duration {
return time.Duration(atomic.LoadInt64(&u.interval))
}
func (u *Updater) SetInterval(d time.Duration) {
atomic.StoreInt64(&u.interval, int64(d))
}
type SystemConnection interface {
net.Conn
Id() internal.ConnectionID
@ -342,14 +350,14 @@ func (v *Connection) SetState(state State) {
case StateTerminating:
v.receivingWorker.CloseRead()
v.sendingWorker.CloseWrite()
v.pingUpdater.interval = time.Second
v.pingUpdater.SetInterval(time.Second)
case StatePeerTerminating:
v.sendingWorker.CloseWrite()
v.pingUpdater.interval = time.Second
v.pingUpdater.SetInterval(time.Second)
case StateTerminated:
v.receivingWorker.CloseRead()
v.sendingWorker.CloseWrite()
v.pingUpdater.interval = time.Second
v.pingUpdater.SetInterval(time.Second)
v.dataUpdater.WakeUp()
v.pingUpdater.WakeUp()
go v.Terminate()
@ -491,7 +499,7 @@ func (v *Connection) Input(segments []Segment) {
case *DataSegment:
v.HandleOption(seg.Option)
v.receivingWorker.ProcessSegment(seg)
if seg.Number == v.receivingWorker.nextNumber {
if v.receivingWorker.IsDataAvailable() {
v.OnDataInput()
}
v.dataUpdater.WakeUp()
@ -573,8 +581,8 @@ func (v *Connection) Ping(current uint32, cmd Command) {
seg := NewCmdOnlySegment()
seg.Conv = v.conv
seg.Cmd = cmd
seg.ReceivinNext = v.receivingWorker.nextNumber
seg.SendingNext = v.sendingWorker.firstUnacknowledged
seg.ReceivinNext = v.receivingWorker.NextNumber()
seg.SendingNext = v.sendingWorker.FirstUnacknowledged()
seg.PeerRTO = v.roundTrip.Timeout()
if v.State() == StateReadyToClose {
seg.Option = SegmentOptionClose

View File

@ -79,7 +79,7 @@ func (o *ServerConnection) Id() internal.ConnectionID {
// Listener defines a server listening for connections
type Listener struct {
sync.Mutex
running bool
closed chan bool
sessions map[ConnectionID]*Connection
awaitingConns chan *Connection
hub *udp.Hub
@ -116,7 +116,7 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
},
sessions: make(map[ConnectionID]*Connection),
awaitingConns: make(chan *Connection, 64),
running: true,
closed: make(chan bool),
config: kcpSettings,
}
if options.Stream != nil && options.Stream.HasSecuritySettings() {
@ -134,7 +134,9 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
if err != nil {
return nil, err
}
l.Lock()
l.hub = hub
l.Unlock()
log.Info("KCP|Listener: listening on ", address, ":", port)
return l, nil
}
@ -148,12 +150,15 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
return
}
if !v.running {
select {
case <-v.closed:
return
default:
}
v.Lock()
defer v.Unlock()
if !v.running {
if v.hub == nil {
return
}
if payload.Len() < 4 {
@ -208,24 +213,22 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
}
func (v *Listener) Remove(id ConnectionID) {
if !v.running {
select {
case <-v.closed:
return
default:
v.Lock()
delete(v.sessions, id)
v.Unlock()
}
v.Lock()
defer v.Unlock()
if !v.running {
return
}
delete(v.sessions, id)
}
// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
func (v *Listener) Accept() (internet.Connection, error) {
for {
if !v.running {
return nil, ErrClosedListener
}
select {
case <-v.closed:
return nil, ErrClosedListener
case conn, open := <-v.awaitingConns:
if !open {
break
@ -243,13 +246,15 @@ func (v *Listener) Accept() (internet.Connection, error) {
// Close stops listening on the UDP address. Already Accepted connections are not closed.
func (v *Listener) Close() error {
if !v.running {
select {
case <-v.closed:
return ErrClosedListener
default:
}
v.Lock()
defer v.Unlock()
v.running = false
close(v.closed)
close(v.awaitingConns)
for _, conn := range v.sessions {
go conn.Terminate()

View File

@ -48,6 +48,10 @@ func (v *ReceivingWindow) RemoveFirst() *DataSegment {
return v.Remove(0)
}
func (w *ReceivingWindow) HasFirst() bool {
return w.list[w.Position(0)] != nil
}
func (v *ReceivingWindow) Advance() {
v.start++
if v.start == v.size {
@ -163,7 +167,9 @@ func NewReceivingWorker(kcp *Connection) *ReceivingWorker {
}
func (v *ReceivingWorker) Release() {
v.Lock()
v.leftOver.Release()
v.Unlock()
}
func (v *ReceivingWorker) ProcessSendingNext(number uint32) {
@ -228,6 +234,19 @@ func (v *ReceivingWorker) Read(b []byte) int {
return total
}
func (w *ReceivingWorker) IsDataAvailable() bool {
w.RLock()
defer w.RUnlock()
return w.window.HasFirst()
}
func (w *ReceivingWorker) NextNumber() uint32 {
w.RLock()
defer w.RUnlock()
return w.nextNumber
}
func (v *ReceivingWorker) Flush(current uint32) {
v.Lock()
defer v.Unlock()
@ -250,5 +269,8 @@ func (v *ReceivingWorker) CloseRead() {
}
func (v *ReceivingWorker) UpdateNecessary() bool {
v.RLock()
defer v.RUnlock()
return len(v.acklist.numbers) > 0
}

View File

@ -207,7 +207,9 @@ func NewSendingWorker(kcp *Connection) *SendingWorker {
}
func (v *SendingWorker) Release() {
v.Lock()
v.window.Release()
v.Unlock()
}
func (v *SendingWorker) ProcessReceivingNext(nextNumber uint32) {
@ -336,7 +338,6 @@ func (v *SendingWorker) OnPacketLoss(lossRate uint32) {
func (v *SendingWorker) Flush(current uint32) {
v.Lock()
defer v.Unlock()
cwnd := v.firstUnacknowledged + v.conn.Config.GetSendingInFlightSize()
if cwnd > v.remoteNextNumber {
@ -348,11 +349,17 @@ func (v *SendingWorker) Flush(current uint32) {
if !v.window.IsEmpty() {
v.window.Flush(current, v.conn.roundTrip.Timeout(), cwnd)
} else if v.firstUnacknowledgedUpdated {
v.conn.Ping(current, CommandPing)
v.firstUnacknowledgedUpdated = false
}
updated := v.firstUnacknowledgedUpdated
v.firstUnacknowledgedUpdated = false
v.Unlock()
if updated {
v.conn.Ping(current, CommandPing)
}
}
func (v *SendingWorker) CloseWrite() {
@ -372,3 +379,10 @@ func (v *SendingWorker) IsEmpty() bool {
func (v *SendingWorker) UpdateNecessary() bool {
return !v.IsEmpty()
}
func (w *SendingWorker) FirstUnacknowledged() uint32 {
w.RLock()
defer w.RUnlock()
return w.firstUnacknowledged
}