migrate to signal.Semaphore and Notifier

pull/861/head
Darien Raymond 2017-12-27 21:33:42 +01:00
parent 664b840812
commit 8a09c6c926
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
3 changed files with 215 additions and 224 deletions

22
common/signal/notifier.go Normal file
View File

@ -0,0 +1,22 @@
package signal
type Notifier struct {
c chan bool
}
func NewNotifier() *Notifier {
return &Notifier{
c: make(chan bool, 1),
}
}
func (n *Notifier) Signal() {
select {
case n.c <- true:
default:
}
}
func (n *Notifier) Wait() <-chan bool {
return n.c
}

View File

@ -9,6 +9,7 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/predicate" "v2ray.com/core/common/predicate"
"v2ray.com/core/common/signal"
) )
var ( var (
@ -120,7 +121,7 @@ type Updater struct {
shouldContinue predicate.Predicate shouldContinue predicate.Predicate
shouldTerminate predicate.Predicate shouldTerminate predicate.Predicate
updateFunc func() updateFunc func()
notifier chan bool notifier *signal.Semaphore
} }
func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater { func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater {
@ -129,31 +130,31 @@ func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTermi
shouldContinue: shouldContinue, shouldContinue: shouldContinue,
shouldTerminate: shouldTerminate, shouldTerminate: shouldTerminate,
updateFunc: updateFunc, updateFunc: updateFunc,
notifier: make(chan bool, 1), notifier: signal.NewSemaphore(1),
} }
go u.Run()
return u return u
} }
func (u *Updater) WakeUp() { func (u *Updater) WakeUp() {
select { select {
case u.notifier <- true: case <-u.notifier.Wait():
go u.run()
default: default:
} }
} }
func (u *Updater) Run() { func (u *Updater) run() {
for <-u.notifier { defer u.notifier.Signal()
if u.shouldTerminate() {
return if u.shouldTerminate() {
} return
ticker := time.NewTicker(u.Interval())
for u.shouldContinue() {
u.updateFunc()
<-ticker.C
}
ticker.Stop()
} }
ticker := time.NewTicker(u.Interval())
for u.shouldContinue() {
u.updateFunc()
<-ticker.C
}
ticker.Stop()
} }
func (u *Updater) Interval() time.Duration { func (u *Updater) Interval() time.Duration {
@ -177,8 +178,8 @@ type Connection struct {
rd time.Time rd time.Time
wd time.Time // write deadline wd time.Time // write deadline
since int64 since int64
dataInput chan bool dataInput *signal.Notifier
dataOutput chan bool dataOutput *signal.Notifier
Config *Config Config *Config
state State state State
@ -206,8 +207,8 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con
meta: meta, meta: meta,
closer: closer, closer: closer,
since: nowMillisec(), since: nowMillisec(),
dataInput: make(chan bool, 1), dataInput: signal.NewNotifier(),
dataOutput: make(chan bool, 1), dataOutput: signal.NewNotifier(),
Config: config, Config: config,
output: NewRetryableWriter(NewSegmentWriter(writer)), output: NewRetryableWriter(NewSegmentWriter(writer)),
mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead, mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead,
@ -241,66 +242,52 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con
return conn return conn
} }
func (v *Connection) Elapsed() uint32 { func (c *Connection) Elapsed() uint32 {
return uint32(nowMillisec() - v.since) return uint32(nowMillisec() - c.since)
}
func (v *Connection) OnDataInput() {
select {
case v.dataInput <- true:
default:
}
}
func (v *Connection) OnDataOutput() {
select {
case v.dataOutput <- true:
default:
}
} }
// ReadMultiBuffer implements buf.Reader. // ReadMultiBuffer implements buf.Reader.
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
if v == nil { if c == nil {
return nil, io.EOF return nil, io.EOF
} }
for { for {
if v.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
return nil, io.EOF return nil, io.EOF
} }
mb := v.receivingWorker.ReadMultiBuffer() mb := c.receivingWorker.ReadMultiBuffer()
if !mb.IsEmpty() { if !mb.IsEmpty() {
return mb, nil return mb, nil
} }
if v.State() == StatePeerTerminating { if c.State() == StatePeerTerminating {
return nil, io.EOF return nil, io.EOF
} }
if err := v.waitForDataInput(); err != nil { if err := c.waitForDataInput(); err != nil {
return nil, err return nil, err
} }
} }
} }
func (v *Connection) waitForDataInput() error { func (c *Connection) waitForDataInput() error {
if v.State() == StatePeerTerminating { if c.State() == StatePeerTerminating {
return io.EOF return io.EOF
} }
duration := time.Minute duration := time.Minute
if !v.rd.IsZero() { if !c.rd.IsZero() {
duration = time.Until(v.rd) duration = time.Until(c.rd)
if duration < 0 { if duration < 0 {
return ErrIOTimeout return ErrIOTimeout
} }
} }
select { select {
case <-v.dataInput: case <-c.dataInput.Wait():
case <-time.After(duration): case <-time.After(duration):
if !v.rd.IsZero() && v.rd.Before(time.Now()) { if !c.rd.IsZero() && c.rd.Before(time.Now()) {
return ErrIOTimeout return ErrIOTimeout
} }
} }
@ -309,39 +296,39 @@ func (v *Connection) waitForDataInput() error {
} }
// Read implements the Conn Read method. // Read implements the Conn Read method.
func (v *Connection) Read(b []byte) (int, error) { func (c *Connection) Read(b []byte) (int, error) {
if v == nil { if c == nil {
return 0, io.EOF return 0, io.EOF
} }
for { for {
if v.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
return 0, io.EOF return 0, io.EOF
} }
nBytes := v.receivingWorker.Read(b) nBytes := c.receivingWorker.Read(b)
if nBytes > 0 { if nBytes > 0 {
return nBytes, nil return nBytes, nil
} }
if err := v.waitForDataInput(); err != nil { if err := c.waitForDataInput(); err != nil {
return 0, err return 0, err
} }
} }
} }
func (v *Connection) waitForDataOutput() error { func (c *Connection) waitForDataOutput() error {
duration := time.Minute duration := time.Minute
if !v.wd.IsZero() { if !c.wd.IsZero() {
duration = time.Until(v.wd) duration = time.Until(c.wd)
if duration < 0 { if duration < 0 {
return ErrIOTimeout return ErrIOTimeout
} }
} }
select { select {
case <-v.dataOutput: case <-c.dataOutput.Wait():
case <-time.After(duration): case <-time.After(duration):
if !v.wd.IsZero() && v.wd.Before(time.Now()) { if !c.wd.IsZero() && c.wd.Before(time.Now()) {
return ErrIOTimeout return ErrIOTimeout
} }
} }
@ -350,295 +337,290 @@ func (v *Connection) waitForDataOutput() error {
} }
// Write implements io.Writer. // Write implements io.Writer.
func (v *Connection) Write(b []byte) (int, error) { func (c *Connection) Write(b []byte) (int, error) {
totalWritten := 0 totalWritten := 0
for { for {
if v == nil || v.State() != StateActive { if c == nil || c.State() != StateActive {
return totalWritten, io.ErrClosedPipe return totalWritten, io.ErrClosedPipe
} }
for v.sendingWorker.Push(func(bb []byte) (int, error) { for c.sendingWorker.Push(func(bb []byte) (int, error) {
n := copy(bb[:v.mss], b[totalWritten:]) n := copy(bb[:c.mss], b[totalWritten:])
totalWritten += n totalWritten += n
return n, nil return n, nil
}) { }) {
v.dataUpdater.WakeUp() c.dataUpdater.WakeUp()
if totalWritten == len(b) { if totalWritten == len(b) {
return totalWritten, nil return totalWritten, nil
} }
} }
if err := v.waitForDataOutput(); err != nil { if err := c.waitForDataOutput(); err != nil {
return totalWritten, err return totalWritten, err
} }
} }
} }
// WriteMultiBuffer implements buf.Writer. // WriteMultiBuffer implements buf.Writer.
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release() defer mb.Release()
for { for {
if v == nil || v.State() != StateActive { if c == nil || c.State() != StateActive {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
for v.sendingWorker.Push(func(bb []byte) (int, error) { for c.sendingWorker.Push(func(bb []byte) (int, error) {
return mb.Read(bb[:v.mss]) return mb.Read(bb[:c.mss])
}) { }) {
v.dataUpdater.WakeUp() c.dataUpdater.WakeUp()
if mb.IsEmpty() { if mb.IsEmpty() {
return nil return nil
} }
} }
if err := v.waitForDataOutput(); err != nil { if err := c.waitForDataOutput(); err != nil {
return err return err
} }
} }
} }
func (v *Connection) SetState(state State) { func (c *Connection) SetState(state State) {
current := v.Elapsed() current := c.Elapsed()
atomic.StoreInt32((*int32)(&v.state), int32(state)) atomic.StoreInt32((*int32)(&c.state), int32(state))
atomic.StoreUint32(&v.stateBeginTime, current) atomic.StoreUint32(&c.stateBeginTime, current)
newError("#", v.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog() newError("#", c.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog()
switch state { switch state {
case StateReadyToClose: case StateReadyToClose:
v.receivingWorker.CloseRead() c.receivingWorker.CloseRead()
case StatePeerClosed: case StatePeerClosed:
v.sendingWorker.CloseWrite() c.sendingWorker.CloseWrite()
case StateTerminating: case StateTerminating:
v.receivingWorker.CloseRead() c.receivingWorker.CloseRead()
v.sendingWorker.CloseWrite() c.sendingWorker.CloseWrite()
v.pingUpdater.SetInterval(time.Second) c.pingUpdater.SetInterval(time.Second)
case StatePeerTerminating: case StatePeerTerminating:
v.sendingWorker.CloseWrite() c.sendingWorker.CloseWrite()
v.pingUpdater.SetInterval(time.Second) c.pingUpdater.SetInterval(time.Second)
case StateTerminated: case StateTerminated:
v.receivingWorker.CloseRead() c.receivingWorker.CloseRead()
v.sendingWorker.CloseWrite() c.sendingWorker.CloseWrite()
v.pingUpdater.SetInterval(time.Second) c.pingUpdater.SetInterval(time.Second)
v.dataUpdater.WakeUp() c.dataUpdater.WakeUp()
v.pingUpdater.WakeUp() c.pingUpdater.WakeUp()
go v.Terminate() go c.Terminate()
} }
} }
// Close closes the connection. // Close closes the connection.
func (v *Connection) Close() error { func (c *Connection) Close() error {
if v == nil { if c == nil {
return ErrClosedConnection return ErrClosedConnection
} }
v.OnDataInput() c.dataInput.Signal()
v.OnDataOutput() c.dataOutput.Signal()
state := v.State() switch c.State() {
if state.Is(StateReadyToClose, StateTerminating, StateTerminated) { case StateReadyToClose, StateTerminating, StateTerminated:
return ErrClosedConnection return ErrClosedConnection
case StateActive:
c.SetState(StateReadyToClose)
case StatePeerClosed:
c.SetState(StateTerminating)
case StatePeerTerminating:
c.SetState(StateTerminated)
} }
newError("closing connection to ", v.meta.RemoteAddr).WriteToLog()
if state == StateActive { newError("closing connection to ", c.meta.RemoteAddr).WriteToLog()
v.SetState(StateReadyToClose)
}
if state == StatePeerClosed {
v.SetState(StateTerminating)
}
if state == StatePeerTerminating {
v.SetState(StateTerminated)
}
return nil return nil
} }
// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
func (v *Connection) LocalAddr() net.Addr { func (c *Connection) LocalAddr() net.Addr {
if v == nil { if c == nil {
return nil return nil
} }
return v.meta.LocalAddr return c.meta.LocalAddr
} }
// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
func (v *Connection) RemoteAddr() net.Addr { func (c *Connection) RemoteAddr() net.Addr {
if v == nil { if c == nil {
return nil return nil
} }
return v.meta.RemoteAddr return c.meta.RemoteAddr
} }
// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
func (v *Connection) SetDeadline(t time.Time) error { func (c *Connection) SetDeadline(t time.Time) error {
if err := v.SetReadDeadline(t); err != nil { if err := c.SetReadDeadline(t); err != nil {
return err return err
} }
if err := v.SetWriteDeadline(t); err != nil { if err := c.SetWriteDeadline(t); err != nil {
return err return err
} }
return nil return nil
} }
// SetReadDeadline implements the Conn SetReadDeadline method. // SetReadDeadline implements the Conn SetReadDeadline method.
func (v *Connection) SetReadDeadline(t time.Time) error { func (c *Connection) SetReadDeadline(t time.Time) error {
if v == nil || v.State() != StateActive { if c == nil || c.State() != StateActive {
return ErrClosedConnection return ErrClosedConnection
} }
v.rd = t c.rd = t
return nil return nil
} }
// SetWriteDeadline implements the Conn SetWriteDeadline method. // SetWriteDeadline implements the Conn SetWriteDeadline method.
func (v *Connection) SetWriteDeadline(t time.Time) error { func (c *Connection) SetWriteDeadline(t time.Time) error {
if v == nil || v.State() != StateActive { if c == nil || c.State() != StateActive {
return ErrClosedConnection return ErrClosedConnection
} }
v.wd = t c.wd = t
return nil return nil
} }
// kcp update, input loop // kcp update, input loop
func (v *Connection) updateTask() { func (c *Connection) updateTask() {
v.flush() c.flush()
} }
func (v *Connection) Terminate() { func (c *Connection) Terminate() {
if v == nil { if c == nil {
return return
} }
newError("terminating connection to ", v.RemoteAddr()).WriteToLog() newError("terminating connection to ", c.RemoteAddr()).WriteToLog()
//v.SetState(StateTerminated) //v.SetState(StateTerminated)
v.OnDataInput() c.dataInput.Signal()
v.OnDataOutput() c.dataOutput.Signal()
v.closer.Close() c.closer.Close()
v.sendingWorker.Release() c.sendingWorker.Release()
v.receivingWorker.Release() c.receivingWorker.Release()
} }
func (v *Connection) HandleOption(opt SegmentOption) { func (c *Connection) HandleOption(opt SegmentOption) {
if (opt & SegmentOptionClose) == SegmentOptionClose { if (opt & SegmentOptionClose) == SegmentOptionClose {
v.OnPeerClosed() c.OnPeerClosed()
} }
} }
func (v *Connection) OnPeerClosed() { func (c *Connection) OnPeerClosed() {
state := v.State() switch c.State() {
if state == StateReadyToClose { case StateReadyToClose:
v.SetState(StateTerminating) c.SetState(StateTerminating)
} case StateActive:
if state == StateActive { c.SetState(StatePeerClosed)
v.SetState(StatePeerClosed)
} }
} }
// Input when you received a low level packet (eg. UDP packet), call it // Input when you received a low level packet (eg. UDP packet), call it
func (v *Connection) Input(segments []Segment) { func (c *Connection) Input(segments []Segment) {
current := v.Elapsed() current := c.Elapsed()
atomic.StoreUint32(&v.lastIncomingTime, current) atomic.StoreUint32(&c.lastIncomingTime, current)
for _, seg := range segments { for _, seg := range segments {
if seg.Conversation() != v.meta.Conversation { if seg.Conversation() != c.meta.Conversation {
break break
} }
switch seg := seg.(type) { switch seg := seg.(type) {
case *DataSegment: case *DataSegment:
v.HandleOption(seg.Option) c.HandleOption(seg.Option)
v.receivingWorker.ProcessSegment(seg) c.receivingWorker.ProcessSegment(seg)
if v.receivingWorker.IsDataAvailable() { if c.receivingWorker.IsDataAvailable() {
v.OnDataInput() c.dataInput.Signal()
} }
v.dataUpdater.WakeUp() c.dataUpdater.WakeUp()
case *AckSegment: case *AckSegment:
v.HandleOption(seg.Option) c.HandleOption(seg.Option)
v.sendingWorker.ProcessSegment(current, seg, v.roundTrip.Timeout()) c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout())
v.OnDataOutput() c.dataOutput.Signal()
v.dataUpdater.WakeUp() c.dataUpdater.WakeUp()
case *CmdOnlySegment: case *CmdOnlySegment:
v.HandleOption(seg.Option) c.HandleOption(seg.Option)
if seg.Command() == CommandTerminate { if seg.Command() == CommandTerminate {
state := v.State() switch c.State() {
if state == StateActive || case StateActive, StatePeerClosed:
state == StatePeerClosed { c.SetState(StatePeerTerminating)
v.SetState(StatePeerTerminating) case StateReadyToClose:
} else if state == StateReadyToClose { c.SetState(StateTerminating)
v.SetState(StateTerminating) case StateTerminating:
} else if state == StateTerminating { c.SetState(StateTerminated)
v.SetState(StateTerminated)
} }
} }
if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate { if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate {
v.OnDataInput() c.dataInput.Signal()
v.OnDataOutput() c.dataOutput.Signal()
} }
v.sendingWorker.ProcessReceivingNext(seg.ReceivinNext) c.sendingWorker.ProcessReceivingNext(seg.ReceivinNext)
v.receivingWorker.ProcessSendingNext(seg.SendingNext) c.receivingWorker.ProcessSendingNext(seg.SendingNext)
v.roundTrip.UpdatePeerRTO(seg.PeerRTO, current) c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current)
seg.Release() seg.Release()
default: default:
} }
} }
} }
func (v *Connection) flush() { func (c *Connection) flush() {
current := v.Elapsed() current := c.Elapsed()
if v.State() == StateTerminated { if c.State() == StateTerminated {
return return
} }
if v.State() == StateActive && current-atomic.LoadUint32(&v.lastIncomingTime) >= 30000 { if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 {
v.Close() c.Close()
} }
if v.State() == StateReadyToClose && v.sendingWorker.IsEmpty() { if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() {
v.SetState(StateTerminating) c.SetState(StateTerminating)
} }
if v.State() == StateTerminating { if c.State() == StateTerminating {
newError("#", v.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog() newError("#", c.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog()
v.Ping(current, CommandTerminate) c.Ping(current, CommandTerminate)
if current-atomic.LoadUint32(&v.stateBeginTime) > 8000 { if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 {
v.SetState(StateTerminated) c.SetState(StateTerminated)
} }
return return
} }
if v.State() == StatePeerTerminating && current-atomic.LoadUint32(&v.stateBeginTime) > 4000 { if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 {
v.SetState(StateTerminating) c.SetState(StateTerminating)
} }
if v.State() == StateReadyToClose && current-atomic.LoadUint32(&v.stateBeginTime) > 15000 { if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 {
v.SetState(StateTerminating) c.SetState(StateTerminating)
} }
// flush acknowledges // flush acknowledges
v.receivingWorker.Flush(current) c.receivingWorker.Flush(current)
v.sendingWorker.Flush(current) c.sendingWorker.Flush(current)
if current-atomic.LoadUint32(&v.lastPingTime) >= 3000 { if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 {
v.Ping(current, CommandPing) c.Ping(current, CommandPing)
} }
} }
func (v *Connection) State() State { func (c *Connection) State() State {
return State(atomic.LoadInt32((*int32)(&v.state))) return State(atomic.LoadInt32((*int32)(&c.state)))
} }
func (v *Connection) Ping(current uint32, cmd Command) { func (c *Connection) Ping(current uint32, cmd Command) {
seg := NewCmdOnlySegment() seg := NewCmdOnlySegment()
seg.Conv = v.meta.Conversation seg.Conv = c.meta.Conversation
seg.Cmd = cmd seg.Cmd = cmd
seg.ReceivinNext = v.receivingWorker.NextNumber() seg.ReceivinNext = c.receivingWorker.NextNumber()
seg.SendingNext = v.sendingWorker.FirstUnacknowledged() seg.SendingNext = c.sendingWorker.FirstUnacknowledged()
seg.PeerRTO = v.roundTrip.Timeout() seg.PeerRTO = c.roundTrip.Timeout()
if v.State() == StateReadyToClose { if c.State() == StateReadyToClose {
seg.Option = SegmentOptionClose seg.Option = SegmentOptionClose
} }
v.output.Write(seg) c.output.Write(seg)
atomic.StoreUint32(&v.lastPingTime, current) atomic.StoreUint32(&c.lastPingTime, current)
seg.Release() seg.Release()
} }

View File

@ -9,6 +9,7 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/platform" "v2ray.com/core/common/platform"
"v2ray.com/core/common/signal"
) )
// NewRay creates a new Ray for direct traffic transport. // NewRay creates a new Ray for direct traffic transport.
@ -57,8 +58,8 @@ type Stream struct {
data buf.MultiBuffer data buf.MultiBuffer
size uint64 size uint64
ctx context.Context ctx context.Context
readSignal chan bool readSignal *signal.Notifier
writeSignal chan bool writeSignal *signal.Notifier
close bool close bool
err bool err bool
} }
@ -67,8 +68,8 @@ type Stream struct {
func NewStream(ctx context.Context) *Stream { func NewStream(ctx context.Context) *Stream {
return &Stream{ return &Stream{
ctx: ctx, ctx: ctx,
readSignal: make(chan bool, 1), readSignal: signal.NewNotifier(),
writeSignal: make(chan bool, 1), writeSignal: signal.NewNotifier(),
size: 0, size: 0,
} }
} }
@ -105,7 +106,7 @@ func (s *Stream) Peek(b *buf.Buffer) {
})) }))
} }
// Read reads data from the Stream. // ReadMultiBuffer reads data from the Stream.
func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) { func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) {
for { for {
mb, err := s.getData() mb, err := s.getData()
@ -114,14 +115,14 @@ func (s *Stream) ReadMultiBuffer() (buf.MultiBuffer, error) {
} }
if mb != nil { if mb != nil {
s.notifyRead() s.readSignal.Signal()
return mb, nil return mb, nil
} }
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return nil, io.EOF return nil, io.EOF
case <-s.writeSignal: case <-s.writeSignal.Wait():
} }
} }
} }
@ -135,7 +136,7 @@ func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
} }
if mb != nil { if mb != nil {
s.notifyRead() s.readSignal.Signal()
return mb, nil return mb, nil
} }
@ -144,7 +145,7 @@ func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
return nil, io.EOF return nil, io.EOF
case <-time.After(timeout): case <-time.After(timeout):
return nil, buf.ErrReadTimeout return nil, buf.ErrReadTimeout
case <-s.writeSignal: case <-s.writeSignal.Wait():
} }
} }
} }
@ -167,7 +168,7 @@ func (s *Stream) waitForStreamSize() error {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return io.ErrClosedPipe return io.ErrClosedPipe
case <-s.readSignal: case <-s.readSignal.Wait():
if s.err || s.close { if s.err || s.close {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
@ -177,7 +178,7 @@ func (s *Stream) waitForStreamSize() error {
return nil return nil
} }
// Write writes more data into the Stream. // WriteMultiBuffer writes more data into the Stream.
func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error { func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error {
if data.IsEmpty() { if data.IsEmpty() {
return nil return nil
@ -202,31 +203,17 @@ func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error {
s.data.AppendMulti(data) s.data.AppendMulti(data)
s.size += uint64(data.Len()) s.size += uint64(data.Len())
s.notifyWrite() s.writeSignal.Signal()
return nil return nil
} }
func (s *Stream) notifyRead() {
select {
case s.readSignal <- true:
default:
}
}
func (s *Stream) notifyWrite() {
select {
case s.writeSignal <- true:
default:
}
}
// Close closes the stream for writing. Read() still works until EOF. // Close closes the stream for writing. Read() still works until EOF.
func (s *Stream) Close() { func (s *Stream) Close() {
s.access.Lock() s.access.Lock()
s.close = true s.close = true
s.notifyRead() s.readSignal.Signal()
s.notifyWrite() s.writeSignal.Signal()
s.access.Unlock() s.access.Unlock()
} }
@ -239,7 +226,7 @@ func (s *Stream) CloseError() {
s.data = nil s.data = nil
s.size = 0 s.size = 0
} }
s.notifyRead() s.readSignal.Signal()
s.notifyWrite() s.writeSignal.Signal()
s.access.Unlock() s.access.Unlock()
} }