diff --git a/transport/ray/direct.go b/transport/ray/direct.go index f2d45b82..134bb44d 100644 --- a/transport/ray/direct.go +++ b/transport/ray/direct.go @@ -40,47 +40,73 @@ func (v *directRay) InboundOutput() InputStream { } type Stream struct { - buffer chan *buf.Buffer + buffer chan *buf.Buffer + srcClose chan bool + destClose chan bool } func NewStream() *Stream { return &Stream{ - buffer: make(chan *buf.Buffer, bufferSize), + buffer: make(chan *buf.Buffer, bufferSize), + srcClose: make(chan bool), + destClose: make(chan bool), } } func (v *Stream) Read() (*buf.Buffer, error) { - buffer, open := <-v.buffer - if !open { - return nil, io.EOF + select { + case <-v.destClose: + return nil, io.ErrClosedPipe + case b := <-v.buffer: + return b, nil + default: + select { + case b := <-v.buffer: + return b, nil + case <-v.srcClose: + return nil, io.EOF + } } - return buffer, nil } func (v *Stream) Write(data *buf.Buffer) (err error) { - defer func() { - if r := recover(); r != nil { - err = io.ErrClosedPipe + select { + case <-v.destClose: + return io.ErrClosedPipe + case <-v.srcClose: + return io.ErrClosedPipe + default: + select { + case <-v.destClose: + return io.ErrClosedPipe + case <-v.srcClose: + return io.ErrClosedPipe + case v.buffer <- data: + return nil } - }() - - v.buffer <- data - return nil + } } func (v *Stream) Close() { defer swallowPanic() - close(v.buffer) + close(v.srcClose) } func (v *Stream) Release() { defer swallowPanic() - close(v.buffer) + close(v.destClose) + v.Close() - for b := range v.buffer { - b.Release() + n := len(v.buffer) + for i := 0; i < n; i++ { + select { + case b := <-v.buffer: + b.Release() + default: + return + } } }