diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 333005e3..53154f67 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -168,15 +168,15 @@ type SystemConnection interface { // Connection is a KCP connection over UDP. type Connection struct { - conn SystemConnection - connRecycler internal.ConnectionRecyler - block internet.Authenticator - rd time.Time - wd time.Time // write deadline - since int64 - dataInputCond *sync.Cond - dataOutputCond *sync.Cond - Config *Config + conn SystemConnection + connRecycler internal.ConnectionRecyler + block internet.Authenticator + rd time.Time + wd time.Time // write deadline + since int64 + dataInput chan bool + dataOutput chan bool + Config *Config conv uint16 state State @@ -203,15 +203,15 @@ func NewConnection(conv uint16, sysConn SystemConnection, recycler internal.Conn log.Info("KCP|Connection: creating connection ", conv) conn := &Connection{ - conv: conv, - conn: sysConn, - connRecycler: recycler, - since: nowMillisec(), - dataInputCond: sync.NewCond(new(sync.Mutex)), - dataOutputCond: sync.NewCond(new(sync.Mutex)), - Config: config, - output: NewSegmentWriter(sysConn, config.GetMtu().GetValue()-uint32(sysConn.Overhead())), - mss: config.GetMtu().GetValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead, + conv: conv, + conn: sysConn, + connRecycler: recycler, + since: nowMillisec(), + dataInput: make(chan bool, 1), + dataOutput: make(chan bool, 1), + Config: config, + output: NewSegmentWriter(sysConn, config.GetMtu().GetValue()-uint32(sysConn.Overhead())), + mss: config.GetMtu().GetValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead, roundTrip: &RoundTripInfo{ rto: 100, minRtt: config.Tti.GetValue(), @@ -247,6 +247,20 @@ func (v *Connection) Elapsed() uint32 { return uint32(nowMillisec() - v.since) } +func (v *Connection) OnDataInput() { + select { + case v.dataInput <- true: + default: + } +} + +func (v *Connection) OnDataOutput() { + select { + case v.dataOutput <- true: + default: + } +} + // Read implements the Conn Read method. func (v *Connection) Read(b []byte) (int, error) { if v == nil { @@ -266,22 +280,20 @@ func (v *Connection) Read(b []byte) (int, error) { return 0, io.EOF } - var timer *time.Timer + duration := time.Duration(time.Minute) if !v.rd.IsZero() { - duration := v.rd.Sub(time.Now()) - if duration <= 0 { + duration = v.rd.Sub(time.Now()) + if duration < 0 { return 0, ErrIOTimeout } - timer = time.AfterFunc(duration, v.dataInputCond.Signal) - } - v.dataInputCond.L.Lock() - v.dataInputCond.Wait() - v.dataInputCond.L.Unlock() - if timer != nil { - timer.Stop() } - if !v.rd.IsZero() && v.rd.Before(time.Now()) { - return 0, ErrIOTimeout + + select { + case <-v.dataInput: + case <-time.After(duration): + if !v.rd.IsZero() && v.rd.Before(time.Now()) { + return 0, ErrIOTimeout + } } } } @@ -304,24 +316,20 @@ func (v *Connection) Write(b []byte) (int, error) { } } - var timer *time.Timer - if !v.wd.IsZero() { - duration := v.wd.Sub(time.Now()) - if duration <= 0 { + duration := time.Duration(time.Minute) + if !v.rd.IsZero() { + duration = v.wd.Sub(time.Now()) + if duration < 0 { return totalWritten, ErrIOTimeout } - timer = time.AfterFunc(duration, v.dataOutputCond.Signal) - } - v.dataOutputCond.L.Lock() - v.dataOutputCond.Wait() - v.dataOutputCond.L.Unlock() - - if timer != nil { - timer.Stop() } - if !v.wd.IsZero() && v.wd.Before(time.Now()) { - return totalWritten, ErrIOTimeout + select { + case <-v.dataOutput: + case <-time.After(duration): + if !v.wd.IsZero() && v.wd.Before(time.Now()) { + return totalWritten, ErrIOTimeout + } } } } @@ -360,8 +368,8 @@ func (v *Connection) Close() error { return ErrClosedConnection } - v.dataInputCond.Broadcast() - v.dataOutputCond.Broadcast() + v.OnDataInput() + v.OnDataOutput() state := v.State() if state.Is(StateReadyToClose, StateTerminating, StateTerminated) { @@ -447,8 +455,9 @@ func (v *Connection) Terminate() { log.Info("KCP|Connection: Terminating connection to ", v.RemoteAddr()) //v.SetState(StateTerminated) - v.dataInputCond.Broadcast() - v.dataOutputCond.Broadcast() + v.OnDataInput() + v.OnDataOutput() + if v.Config.ConnectionReuse.IsEnabled() && v.reusable { v.connRecycler.Put(v.conn.Id(), v.conn) } else { @@ -481,19 +490,21 @@ func (v *Connection) Input(segments []Segment) { for _, seg := range segments { if seg.Conversation() != v.conv { - return + break } switch seg := seg.(type) { case *DataSegment: v.HandleOption(seg.Option) v.receivingWorker.ProcessSegment(seg) - v.dataInputCond.Signal() + if seg.Number == v.receivingWorker.nextNumber { + v.OnDataInput() + } v.dataUpdater.WakeUp() case *AckSegment: v.HandleOption(seg.Option) v.sendingWorker.ProcessSegment(current, seg, v.roundTrip.Timeout()) - v.dataOutputCond.Signal() + v.OnDataOutput() v.dataUpdater.WakeUp() case *CmdOnlySegment: v.HandleOption(seg.Option)