From 06c92e492d296365f2886a2325d928d3e319f9a8 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Thu, 22 Dec 2016 17:28:06 +0100 Subject: [PATCH] fix ray stream --- transport/ray/direct.go | 78 ++++++++++-------------------------- transport/ray/direct_test.go | 42 +++++++++++++++++++ 2 files changed, 63 insertions(+), 57 deletions(-) create mode 100644 transport/ray/direct_test.go diff --git a/transport/ray/direct.go b/transport/ray/direct.go index b3541210..f2d45b82 100644 --- a/transport/ray/direct.go +++ b/transport/ray/direct.go @@ -2,8 +2,6 @@ package ray import ( "io" - "sync" - "time" "v2ray.com/core/common/buf" ) @@ -42,8 +40,6 @@ func (v *directRay) InboundOutput() InputStream { } type Stream struct { - access sync.RWMutex - closed bool buffer chan *buf.Buffer } @@ -54,72 +50,40 @@ func NewStream() *Stream { } func (v *Stream) Read() (*buf.Buffer, error) { - if v.buffer == nil { - return nil, io.EOF - } - v.access.RLock() - if v.buffer == nil { - v.access.RUnlock() - return nil, io.EOF - } - channel := v.buffer - v.access.RUnlock() - result, open := <-channel + buffer, open := <-v.buffer if !open { return nil, io.EOF } - return result, nil + return buffer, nil } -func (v *Stream) Write(data *buf.Buffer) error { - for !v.closed { - err := v.TryWriteOnce(data) - if err != io.ErrNoProgress { - return err +func (v *Stream) Write(data *buf.Buffer) (err error) { + defer func() { + if r := recover(); r != nil { + err = io.ErrClosedPipe } - } - return io.ErrClosedPipe -} + }() -func (v *Stream) TryWriteOnce(data *buf.Buffer) error { - v.access.RLock() - defer v.access.RUnlock() - if v.closed { - return io.ErrClosedPipe - } - select { - case v.buffer <- data: - return nil - case <-time.After(2 * time.Second): - return io.ErrNoProgress - } + v.buffer <- data + return nil } func (v *Stream) Close() { - if v.closed { - return - } - v.access.Lock() - defer v.access.Unlock() - if v.closed { - return - } - v.closed = true + defer swallowPanic() + close(v.buffer) } func (v *Stream) Release() { - if v.buffer == nil { - return - } - v.Close() - v.access.Lock() - defer v.access.Unlock() - if v.buffer == nil { - return - } - for data := range v.buffer { - data.Release() + defer swallowPanic() + + close(v.buffer) + + for b := range v.buffer { + b.Release() } - v.buffer = nil +} + +func swallowPanic() { + recover() } diff --git a/transport/ray/direct_test.go b/transport/ray/direct_test.go new file mode 100644 index 00000000..e8b4ed19 --- /dev/null +++ b/transport/ray/direct_test.go @@ -0,0 +1,42 @@ +package ray_test + +import ( + "io" + "testing" + + "v2ray.com/core/common/buf" + "v2ray.com/core/testing/assert" + . "v2ray.com/core/transport/ray" +) + +func TestStreamIO(t *testing.T) { + assert := assert.On(t) + + stream := NewStream() + assert.Error(stream.Write(buf.New())).IsNil() + + _, err := stream.Read() + assert.Error(err).IsNil() + + stream.Close() + _, err = stream.Read() + assert.Error(err).Equals(io.EOF) + + err = stream.Write(buf.New()) + assert.Error(err).Equals(io.ErrClosedPipe) +} + +func TestStreamClose(t *testing.T) { + assert := assert.On(t) + + stream := NewStream() + assert.Error(stream.Write(buf.New())).IsNil() + + stream.Close() + + _, err := stream.Read() + assert.Error(err).IsNil() + + _, err = stream.Read() + assert.Error(err).Equals(io.EOF) +}