fix ray stream

pull/255/merge
Darien Raymond 8 years ago
parent 56f08afd9c
commit 06c92e492d
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169

@ -2,8 +2,6 @@ package ray
import (
"io"
"sync"
"time"
"v2ray.com/core/common/buf"
)
@ -42,8 +40,6 @@ func (v *directRay) InboundOutput() InputStream {
}
type Stream struct {
access sync.RWMutex
closed bool
buffer chan *buf.Buffer
}
@ -54,72 +50,40 @@ func NewStream() *Stream {
}
func (v *Stream) Read() (*buf.Buffer, error) {
if v.buffer == nil {
return nil, io.EOF
}
v.access.RLock()
if v.buffer == nil {
v.access.RUnlock()
return nil, io.EOF
}
channel := v.buffer
v.access.RUnlock()
result, open := <-channel
buffer, open := <-v.buffer
if !open {
return nil, io.EOF
}
return result, nil
return buffer, nil
}
func (v *Stream) Write(data *buf.Buffer) error {
for !v.closed {
err := v.TryWriteOnce(data)
if err != io.ErrNoProgress {
return err
}
}
return io.ErrClosedPipe
func (v *Stream) Write(data *buf.Buffer) (err error) {
defer func() {
if r := recover(); r != nil {
err = io.ErrClosedPipe
}
}()
func (v *Stream) TryWriteOnce(data *buf.Buffer) error {
v.access.RLock()
defer v.access.RUnlock()
if v.closed {
return io.ErrClosedPipe
}
select {
case v.buffer <- data:
v.buffer <- data
return nil
case <-time.After(2 * time.Second):
return io.ErrNoProgress
}
}
func (v *Stream) Close() {
if v.closed {
return
}
v.access.Lock()
defer v.access.Unlock()
if v.closed {
return
}
v.closed = true
defer swallowPanic()
close(v.buffer)
}
func (v *Stream) Release() {
if v.buffer == nil {
return
}
v.Close()
v.access.Lock()
defer v.access.Unlock()
if v.buffer == nil {
return
defer swallowPanic()
close(v.buffer)
for b := range v.buffer {
b.Release()
}
for data := range v.buffer {
data.Release()
}
v.buffer = nil
func swallowPanic() {
recover()
}

@ -0,0 +1,42 @@
package ray_test
import (
"io"
"testing"
"v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/ray"
)
func TestStreamIO(t *testing.T) {
assert := assert.On(t)
stream := NewStream()
assert.Error(stream.Write(buf.New())).IsNil()
_, err := stream.Read()
assert.Error(err).IsNil()
stream.Close()
_, err = stream.Read()
assert.Error(err).Equals(io.EOF)
err = stream.Write(buf.New())
assert.Error(err).Equals(io.ErrClosedPipe)
}
func TestStreamClose(t *testing.T) {
assert := assert.On(t)
stream := NewStream()
assert.Error(stream.Write(buf.New())).IsNil()
stream.Close()
_, err := stream.Read()
assert.Error(err).IsNil()
_, err = stream.Read()
assert.Error(err).Equals(io.EOF)
}
Loading…
Cancel
Save