migrate to signal.Semaphore and Notifier

pull/861/head
Darien Raymond 2017-12-27 21:33:42 +01:00
parent 664b840812
commit 8a09c6c926
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
3 changed files with 215 additions and 224 deletions

22
common/signal/notifier.go Normal file
View File

@ -0,0 +1,22 @@
package signal
type Notifier struct {
c chan bool
}
func NewNotifier() *Notifier {
return &Notifier{
c: make(chan bool, 1),
}
}
func (n *Notifier) Signal() {
select {
case n.c <- true:
default:
}
}
func (n *Notifier) Wait() <-chan bool {
return n.c
}

View File

@ -9,6 +9,7 @@ import (
"v2ray.com/core/common/buf"
"v2ray.com/core/common/predicate"
"v2ray.com/core/common/signal"
)
var (
@ -120,7 +121,7 @@ type Updater struct {
shouldContinue predicate.Predicate
shouldTerminate predicate.Predicate
updateFunc func()
notifier chan bool
notifier *signal.Semaphore
}
func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater {
@ -129,31 +130,31 @@ func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTermi
shouldContinue: shouldContinue,
shouldTerminate: shouldTerminate,
updateFunc: updateFunc,
notifier: make(chan bool, 1),
notifier: signal.NewSemaphore(1),
}
go u.Run()
return u
}
func (u *Updater) WakeUp() {
select {
case u.notifier <- true:
case <-u.notifier.Wait():
go u.run()
default:
}
}
func (u *Updater) Run() {
for <-u.notifier {
if u.shouldTerminate() {
return
}
ticker := time.NewTicker(u.Interval())
for u.shouldContinue() {
u.updateFunc()
<-ticker.C
}
ticker.Stop()
func (u *Updater) run() {
defer u.notifier.Signal()
if u.shouldTerminate() {
return
}
ticker := time.NewTicker(u.Interval())
for u.shouldContinue() {
u.updateFunc()
<-ticker.C
}
ticker.Stop()
}
func (u *Updater) Interval() time.Duration {
@ -177,8 +178,8 @@ type Connection struct {
rd time.Time
wd time.Time // write deadline
since int64
dataInput chan bool
dataOutput chan bool
dataInput *signal.Notifier
dataOutput *signal.Notifier
Config *Config
state State
@ -206,8 +207,8 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con
meta: meta,
closer: closer,
since: nowMillisec(),
dataInput: make(chan bool, 1),
dataOutput: make(chan bool, 1),
dataInput: signal.NewNotifier(),
dataOutput: signal.NewNotifier(),
Config: config,
output: NewRetryableWriter(NewSegmentWriter(writer)),
mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead,
@ -241,66 +242,52 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con
return conn
}
func (v *Connection) Elapsed() uint32 {
return uint32(nowMillisec() - v.since)
}
func (v *Connection) OnDataInput() {
select {
case v.dataInput <- true:
default:
}
}
func (v *Connection) OnDataOutput() {
select {
case v.dataOutput <- true:
default:
}
func (c *Connection) Elapsed() uint32 {
return uint32(nowMillisec() - c.since)
}
// ReadMultiBuffer implements buf.Reader.
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
if v == nil {
func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
if c == nil {
return nil, io.EOF
}
for {
if v.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
return nil, io.EOF
}
mb := v.receivingWorker.ReadMultiBuffer()
mb := c.receivingWorker.ReadMultiBuffer()
if !mb.IsEmpty() {
return mb, nil
}
if v.State() == StatePeerTerminating {
if c.State() == StatePeerTerminating {
return nil, io.EOF
}
if err := v.waitForDataInput(); err != nil {
if err := c.waitForDataInput(); err != nil {
return nil, err
}
}
}
func (v *Connection) waitForDataInput() error {
if v.State() == StatePeerTerminating {
func (c *Connection) waitForDataInput() error {
if c.State() == StatePeerTerminating {
return io.EOF
}
duration := time.Minute
if !v.rd.IsZero() {
duration = time.Until(v.rd)
if !c.rd.IsZero() {
duration = time.Until(c.rd)
if duration < 0 {
return ErrIOTimeout
}
}
select {
case <-v.dataInput:
case <-c.dataInput.Wait():
case <-time.After(duration):
if !v.rd.IsZero() && v.rd.Before(time.Now()) {
if !c.rd.IsZero() && c.rd.Before(time.Now()) {
return ErrIOTimeout
}
}
@ -309,39 +296,39 @@ func (v *Connection) waitForDataInput() error {
}
// Read implements the Conn Read method.
func (v *Connection) Read(b []byte) (int, error) {
if v == nil {
func (c *Connection) Read(b []byte) (int, error) {
if c == nil {
return 0, io.EOF
}
for {
if v.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
return 0, io.EOF
}
nBytes := v.receivingWorker.Read(b)
nBytes := c.receivingWorker.Read(b)
if nBytes > 0 {
return nBytes, nil
}
if err := v.waitForDataInput(); err != nil {
if err := c.waitForDataInput(); err != nil {
return 0, err
}
}
}
func (v *Connection) waitForDataOutput() error {
func (c *Connection) waitForDataOutput() error {
duration := time.Minute
if !v.wd.IsZero() {
duration = time.Until(v.wd)
if !c.wd.IsZero() {
duration = time.Until(c.wd)
if duration < 0 {
return ErrIOTimeout
}
}
select {
case <-v.dataOutput:
case <-c.dataOutput.Wait():
case <-time.After(duration):
if !v.wd.IsZero() && v.wd.Before(time.Now()) {
if !c.wd.IsZero() && c.wd.Before(time.Now()) {
return ErrIOTimeout
}
}
@ -350,295 +337,290 @@ func (v *Connection) waitForDataOutput() error {
}
// Write implements io.Writer.
func (v *Connection) Write(b []byte) (int, error) {
func (c *Connection) Write(b []byte) (int, error) {
totalWritten := 0
for {
if v == nil || v.State() != StateActive {
if c == nil || c.State() != StateActive {
return totalWritten, io.ErrClosedPipe
}
for v.sendingWorker.Push(func(bb []byte) (int, error) {
n := copy(bb[:v.mss], b[totalWritten:])
for c.sendingWorker.Push(func(bb []byte) (int, error) {
n := copy(bb[:c.mss], b[totalWritten:])
totalWritten += n
return n, nil
}) {
v.dataUpdater.WakeUp()
c.dataUpdater.WakeUp()
if totalWritten == len(b) {
return totalWritten, nil
}
}
if err := v.waitForDataOutput(); err != nil {
if err := c.waitForDataOutput(); err != nil {
return totalWritten, err
}
}
}
// WriteMultiBuffer implements buf.Writer.
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release()
for {
if v == nil || v.State() != StateActive {
if c == nil || c.State() != StateActive {
return io.ErrClosedPipe
}
for v.sendingWorker.Push(func(bb []byte) (int, error) {
return mb.Read(bb[:v.mss])
for c.sendingWorker.Push(func(bb []byte) (int, error) {
return mb.Read(bb[:c.mss])
}) {
v.dataUpdater.WakeUp()
c.dataUpdater.WakeUp()
if mb.IsEmpty() {
return nil
}
}
if err := v.waitForDataOutput(); err != nil {
if err := c.waitForDataOutput(); err != nil {
return err
}
}
}
func (v *Connection) SetState(state State) {
current := v.Elapsed()
atomic.StoreInt32((*int32)(&v.state), int32(state))
atomic.StoreUint32(&v.stateBeginTime, current)
newError("#", v.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog()
func (c *Connection) SetState(state State) {
current := c.Elapsed()
atomic.StoreInt32((*int32)(&c.state), int32(state))
atomic.StoreUint32(&c.stateBeginTime, current)
newError("#", c.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog()
switch state {
case StateReadyToClose:
v.receivingWorker.CloseRead()
c.receivingWorker.CloseRead()
case StatePeerClosed:
v.sendingWorker.CloseWrite()
c.sendingWorker.CloseWrite()
case StateTerminating:
v.receivingWorker.CloseRead()
v.sendingWorker.CloseWrite()
v.pingUpdater.SetInterval(time.Second)
c.receivingWorker.CloseRead()
c.sendingWorker.CloseWrite()
c.pingUpdater.SetInterval(time.Second)
case StatePeerTerminating:
v.sendingWorker.CloseWrite()
v.pingUpdater.SetInterval(time.Second)
c.sendingWorker.CloseWrite()
c.pingUpdater.SetInterval(time.Second)
case StateTerminated:
v.receivingWorker.CloseRead()
v.sendingWorker.CloseWrite()
v.pingUpdater.SetInterval(time.Second)
v.dataUpdater.WakeUp()
v.pingUpdater.WakeUp()
go v.Terminate()
c.receivingWorker.CloseRead()
c.sendingWorker.CloseWrite()
c.pingUpdater.SetInterval(time.Second)
c.dataUpdater.WakeUp()
c.pingUpdater.WakeUp()
go c.Terminate()
}
}
// Close closes the connection.
func (v *Connection) Close() error {
if v == nil {
func (c *Connection) Close() error {
if c == nil {
return ErrClosedConnection
}
v.OnDataInput()
v.OnDataOutput()
c.dataInput.Signal()
c.dataOutput.Signal()
state := v.State()
if state.Is(StateReadyToClose, StateTerminating, StateTerminated) {
switch c.State() {
case StateReadyToClose, StateTerminating, StateTerminated:
return ErrClosedConnection
case StateActive:
c.SetState(StateReadyToClose)
case StatePeerClosed:
c.SetState(StateTerminating)
case StatePeerTerminating:
c.SetState(StateTerminated)
}
newError("closing connection to ", v.meta.RemoteAddr).WriteToLog()
if state == StateActive {
v.SetState(StateReadyToClose)
}
if state == StatePeerClosed {
v.SetState(StateTerminating)
}
if state == StatePeerTerminating {
v.SetState(StateTerminated)
}
newError("closing connection to ", c.meta.RemoteAddr).WriteToLog()
return nil
}
// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
func (v *Connection) LocalAddr() net.Addr {
if v == nil {
func (c *Connection) LocalAddr() net.Addr {
if c == nil {
return nil
}
return v.meta.LocalAddr
return c.meta.LocalAddr
}
// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
func (v *Connection) RemoteAddr() net.Addr {
if v == nil {
func (c *Connection) RemoteAddr() net.Addr {
if c == nil {
return nil
}
return v.meta.RemoteAddr
return c.meta.RemoteAddr
}
// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
func (v *Connection) SetDeadline(t time.Time) error {
if err := v.SetReadDeadline(t); err != nil {
func (c *Connection) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
if err := v.SetWriteDeadline(t); err != nil {
if err := c.SetWriteDeadline(t); err != nil {
return err
}
return nil
}
// SetReadDeadline implements the Conn SetReadDeadline method.
func (v *Connection) SetReadDeadline(t time.Time) error {
if v == nil || v.State() != StateActive {
func (c *Connection) SetReadDeadline(t time.Time) error {
if c == nil || c.State() != StateActive {
return ErrClosedConnection
}
v.rd = t
c.rd = t
return nil
}
// SetWriteDeadline implements the Conn SetWriteDeadline method.
func (v *Connection) SetWriteDeadline(t time.Time) error {
if v == nil || v.State() != StateActive {
func (c *Connection) SetWriteDeadline(t time.Time) error {
if c == nil || c.State() != StateActive {
return ErrClosedConnection
}
v.wd = t
c.wd = t
return nil
}
// kcp update, input loop
func (v *Connection) updateTask() {
v.flush()
func (c *Connection) updateTask() {
c.flush()
}
func (v *Connection) Terminate() {
if v == nil {
func (c *Connection) Terminate() {
if c == nil {
return
}
newError("terminating connection to ", v.RemoteAddr()).WriteToLog()
newError("terminating connection to ", c.RemoteAddr()).WriteToLog()
//v.SetState(StateTerminated)
v.OnDataInput()
v.OnDataOutput()
c.dataInput.Signal()
c.dataOutput.Signal()
v.closer.Close()
v.sendingWorker.Release()
v.receivingWorker.Release()
c.closer.Close()
c.sendingWorker.Release()
c.receivingWorker.Release()
}
func (v *Connection) HandleOption(opt SegmentOption) {
func (c *Connection) HandleOption(opt SegmentOption) {
if (opt & SegmentOptionClose) == SegmentOptionClose {
v.OnPeerClosed()
c.OnPeerClosed()
}
}
func (v *Connection) OnPeerClosed() {
state := v.State()
if state == StateReadyToClose {
v.SetState(StateTerminating)
}
if state == StateActive {
v.SetState(StatePeerClosed)
func (c *Connection) OnPeerClosed() {
switch c.State() {
case StateReadyToClose:
c.SetState(StateTerminating)
case StateActive:
c.SetState(StatePeerClosed)
}
}
// Input when you received a low level packet (eg. UDP packet), call it
func (v *Connection) Input(segments []Segment) {
current := v.Elapsed()
atomic.StoreUint32(&v.lastIncomingTime, current)
func (c *Connection) Input(segments []Segment) {
current := c.Elapsed()
atomic.StoreUint32(&c.lastIncomingTime, current)
for _, seg := range segments {
if seg.Conversation() != v.meta.Conversation {
if seg.Conversation() != c.meta.Conversation {
break
}
switch seg := seg.(type) {
case *DataSegment:
v.HandleOption(seg.Option)
v.receivingWorker.ProcessSegment(seg)
if v.receivingWorker.IsDataAvailable() {
v.OnDataInput()
c.HandleOption(seg.Option)
c.receivingWorker.ProcessSegment(seg)
if c.receivingWorker.IsDataAvailable() {
c.dataInput.Signal()
}
v.dataUpdater.WakeUp()
c.dataUpdater.WakeUp()
case *AckSegment:
v.HandleOption(seg.Option)
v.sendingWorker.ProcessSegment(current, seg, v.roundTrip.Timeout())
v.OnDataOutput()
v.dataUpdater.WakeUp()
c.HandleOption(seg.Option)
c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout())
c.dataOutput.Signal()
c.dataUpdater.WakeUp()
case *CmdOnlySegment:
v.HandleOption(seg.Option)
c.HandleOption(seg.Option)
if seg.Command() == CommandTerminate {
state := v.State()
if state == StateActive ||
state == StatePeerClosed {
v.SetState(StatePeerTerminating)
} else if state == StateReadyToClose {
v.SetState(StateTerminating)
} else if state == StateTerminating {
v.SetState(StateTerminated)
switch c.State() {
case StateActive, StatePeerClosed:
c.SetState(StatePeerTerminating)
case StateReadyToClose:
c.SetState(StateTerminating)
case StateTerminating:
c.SetState(StateTerminated)
}
}
if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate {
v.OnDataInput()
v.OnDataOutput()
c.dataInput.Signal()
c.dataOutput.Signal()
}
v.sendingWorker.ProcessReceivingNext(seg.ReceivinNext)
v.receivingWorker.ProcessSendingNext(seg.SendingNext)
v.roundTrip.UpdatePeerRTO(seg.PeerRTO, current)
c.sendingWorker.ProcessReceivingNext(seg.ReceivinNext)
c.receivingWorker.ProcessSendingNext(seg.SendingNext)
c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current)
seg.Release()
default:
}
}
}
func (v *Connection) flush() {
current := v.Elapsed()
func (c *Connection) flush() {
current := c.Elapsed()
if v.State() == StateTerminated {
if c.State() == StateTerminated {
return
}
if v.State() == StateActive && current-atomic.LoadUint32(&v.lastIncomingTime) >= 30000 {
v.Close()
if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 {
c.Close()
}
if v.State() == StateReadyToClose && v.sendingWorker.IsEmpty() {
v.SetState(StateTerminating)
if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() {
c.SetState(StateTerminating)
}
if v.State() == StateTerminating {
newError("#", v.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog()
v.Ping(current, CommandTerminate)
if c.State() == StateTerminating {
newError("#", c.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog()
c.Ping(current, CommandTerminate)
if current-atomic.LoadUint32(&v.stateBeginTime) > 8000 {
v.SetState(StateTerminated)
if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 {
c.SetState(StateTerminated)
}
return
}
if v.State() == StatePeerTerminating && current-atomic.LoadUint32(&v.stateBeginTime) > 4000 {
v.SetState(StateTerminating)
if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 {
c.SetState(StateTerminating)
}
if v.State() == StateReadyToClose && current-atomic.LoadUint32(&v.stateBeginTime) > 15000 {
v.SetState(StateTerminating)
if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 {
c.SetState(StateTerminating)
}
// flush acknowledges
v.receivingWorker.Flush(current)
v.sendingWorker.Flush(current)
c.receivingWorker.Flush(current)
c.sendingWorker.Flush(current)
if current-atomic.LoadUint32(&v.lastPingTime) >= 3000 {
v.Ping(current, CommandPing)
if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 {
c.Ping(current, CommandPing)
}
}
func (v *Connection) State() State {
return State(atomic.LoadInt32((*int32)(&v.state)))
func (c *Connection) State() State {
return State(atomic.LoadInt32((*int32)(&c.state)))
}
func (v *Connection) Ping(current uint32, cmd Command) {
func (c *Connection) Ping(current uint32, cmd Command) {
seg := NewCmdOnlySegment()
seg.Conv = v.meta.Conversation
seg.Conv = c.meta.Conversation
seg.Cmd = cmd
seg.ReceivinNext = v.receivingWorker.NextNumber()
seg.SendingNext = v.sendingWorker.FirstUnacknowledged()
seg.PeerRTO = v.roundTrip.Timeout()
if v.State() == StateReadyToClose {
seg.ReceivinNext = c.receivingWorker.NextNumber()
seg.SendingNext = c.sendingWorker.FirstUnacknowledged()
seg.PeerRTO = c.roundTrip.Timeout()
if c.State() == StateReadyToClose {
seg.Option = SegmentOptionClose
}
v.output.Write(seg)
atomic.StoreUint32(&v.lastPingTime, current)
c.output.Write(seg)
atomic.StoreUint32(&c.lastPingTime, current)
seg.Release()
}

View File

@ -9,6 +9,7 @@ import (
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/platform"
"v2ray.com/core/common/signal"
)
// NewRay creates a new Ray for direct traffic transport.
@ -57,8 +58,8 @@ type Stream struct {
data buf.MultiBuffer
size uint64
ctx context.Context
readSignal chan bool
writeSignal chan bool
readSignal *signal.Notifier
writeSignal *signal.Notifier
close bool
err bool
}
@ -67,8 +68,8 @@ type Stream struct {
func NewStream(ctx context.Context) *Stream {
return &Stream{
ctx: ctx,
readSignal: make(chan bool, 1),
writeSignal: make(chan bool, 1),
readSignal: signal.NewNotifier(),
writeSignal: signal.NewNotifier(),
size: 0,
}
}
@ -105,7 +106,7 @@ func (s *Stream) Peek(b *buf.Buffer) {
}))
}
// Read reads data from the Stream.
// ReadMultiBuffer reads data from the Stream.
func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) {
for {
mb, err := s.getData()
@ -114,14 +115,14 @@ func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) {
}
if mb != nil {
s.notifyRead()
s.readSignal.Signal()
return mb, nil
}
select {
case <-s.ctx.Done():
return nil, io.EOF
case <-s.writeSignal:
case <-s.writeSignal.Wait():
}
}
}
@ -135,7 +136,7 @@ func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
}
if mb != nil {
s.notifyRead()
s.readSignal.Signal()
return mb, nil
}
@ -144,7 +145,7 @@ func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
return nil, io.EOF
case <-time.After(timeout):
return nil, buf.ErrReadTimeout
case <-s.writeSignal:
case <-s.writeSignal.Wait():
}
}
}
@ -167,7 +168,7 @@ func (s *Stream) waitForStreamSize() error {
select {
case <-s.ctx.Done():
return io.ErrClosedPipe
case <-s.readSignal:
case <-s.readSignal.Wait():
if s.err || s.close {
return io.ErrClosedPipe
}
@ -177,7 +178,7 @@ func (s *Stream) waitForStreamSize() error {
return nil
}
// Write writes more data into the Stream.
// WriteMultiBuffer writes more data into the Stream.
func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error {
if data.IsEmpty() {
return nil
@ -202,31 +203,17 @@ func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error {
s.data.AppendMulti(data)
s.size += uint64(data.Len())
s.notifyWrite()
s.writeSignal.Signal()
return nil
}
func (s *Stream) notifyRead() {
select {
case s.readSignal <- true:
default:
}
}
func (s *Stream) notifyWrite() {
select {
case s.writeSignal <- true:
default:
}
}
// Close closes the stream for writing. Read() still works until EOF.
func (s *Stream) Close() {
s.access.Lock()
s.close = true
s.notifyRead()
s.notifyWrite()
s.readSignal.Signal()
s.writeSignal.Signal()
s.access.Unlock()
}
@ -239,7 +226,7 @@ func (s *Stream) CloseError() {
s.data = nil
s.size = 0
}
s.notifyRead()
s.notifyWrite()
s.readSignal.Signal()
s.writeSignal.Signal()
s.access.Unlock()
}