replace channel with pipe in udp conn

pull/1269/head
Darien Raymond 6 years ago
parent f3feec8acf
commit 91109f3657
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169

@ -2,7 +2,6 @@ package inbound
import ( import (
"context" "context"
"io"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -20,6 +19,7 @@ import (
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/tcp" "v2ray.com/core/transport/internet/tcp"
"v2ray.com/core/transport/internet/udp" "v2ray.com/core/transport/internet/udp"
"v2ray.com/core/transport/pipe"
) )
type worker interface { type worker interface {
@ -121,7 +121,8 @@ func (w *tcpWorker) Port() net.Port {
type udpConn struct { type udpConn struct {
lastActivityTime int64 // in seconds lastActivityTime int64 // in seconds
input chan *buf.Buffer reader buf.Reader
writer buf.Writer
output func([]byte) (int, error) output func([]byte) (int, error)
remote net.Addr remote net.Addr
local net.Addr local net.Addr
@ -136,52 +137,21 @@ func (c *udpConn) updateActivity() {
// ReadMultiBuffer implements buf.Reader // ReadMultiBuffer implements buf.Reader
func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) { func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
var payload buf.MultiBuffer mb, err := c.reader.ReadMultiBuffer()
if err != nil {
select { return nil, err
case in := <-c.input:
payload.Append(in)
default:
select {
case in := <-c.input:
payload.Append(in)
case <-c.done.Wait():
return nil, io.EOF
}
}
L:
for {
select {
case in := <-c.input:
payload.Append(in)
default:
break L
}
} }
c.updateActivity() c.updateActivity()
if c.uplink != nil { if c.uplink != nil {
c.uplink.Add(int64(payload.Len())) c.uplink.Add(int64(mb.Len()))
} }
return payload, nil return mb, nil
} }
func (c *udpConn) Read(buf []byte) (int, error) { func (c *udpConn) Read(buf []byte) (int, error) {
select { panic("not implemented")
case in := <-c.input:
defer in.Release()
c.updateActivity()
nBytes := copy(buf, in.Bytes())
if c.uplink != nil {
c.uplink.Add(int64(nBytes))
}
return nBytes, nil
case <-c.done.Wait():
return 0, io.EOF
}
} }
// Write implements io.Writer. // Write implements io.Writer.
@ -198,6 +168,7 @@ func (c *udpConn) Write(buf []byte) (int, error) {
func (c *udpConn) Close() error { func (c *udpConn) Close() error {
common.Must(c.done.Close()) common.Must(c.done.Close())
common.Must(common.Close(c.writer))
return nil return nil
} }
@ -251,8 +222,10 @@ func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
return conn, true return conn, true
} }
pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
conn := &udpConn{ conn := &udpConn{
input: make(chan *buf.Buffer, 32), reader: pReader,
writer: pWriter,
output: func(b []byte) (int, error) { output: func(b []byte) (int, error) {
return w.hub.WriteTo(b, id.src) return w.hub.WriteTo(b, id.src)
}, },
@ -282,13 +255,9 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
id.dest = originalDest id.dest = originalDest
} }
conn, existing := w.getConnection(id) conn, existing := w.getConnection(id)
select {
case conn.input <- b: // payload will be discarded in pipe is full.
case <-conn.done.Wait(): conn.writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)) // nolint: errcheck
b.Release()
default:
b.Release()
}
if !existing { if !existing {
common.Must(w.checker.Start()) common.Must(w.checker.Start())

@ -28,6 +28,7 @@ type pipe struct {
done *done.Instance done *done.Instance
limit int32 limit int32
state state state state
discardOverflow bool
} }
var errBufferFull = errors.New("buffer full") var errBufferFull = errors.New("buffer full")
@ -121,10 +122,14 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
for { for {
err := p.writeMultiBufferInternal(mb) err := p.writeMultiBufferInternal(mb)
if err == nil { switch {
case err == nil:
p.readSignal.Signal() p.readSignal.Signal()
return nil return nil
} else if err != errBufferFull { case err == errBufferFull && p.discardOverflow:
mb.Release()
return nil
case err != errBufferFull:
mb.Release() mb.Release()
p.readSignal.Signal() p.readSignal.Signal()
return err return err

@ -11,18 +11,28 @@ import (
// Option for creating new Pipes. // Option for creating new Pipes.
type Option func(*pipe) type Option func(*pipe)
// WithoutSizeLimit returns an Option for Pipe to have no size limit.
func WithoutSizeLimit() Option { func WithoutSizeLimit() Option {
return func(p *pipe) { return func(p *pipe) {
p.limit = -1 p.limit = -1
} }
} }
// WithSizeLimit returns an Option for Pipe to have the given size limit.
func WithSizeLimit(limit int32) Option { func WithSizeLimit(limit int32) Option {
return func(p *pipe) { return func(p *pipe) {
p.limit = limit p.limit = limit
} }
} }
// DiscardOverflow returns an Option for Pipe to discard writes if full.
func DiscardOverflow() Option {
return func(p *pipe) {
p.discardOverflow = true
}
}
// OptionsFromContext returns a list of Options from context.
func OptionsFromContext(ctx context.Context) []Option { func OptionsFromContext(ctx context.Context) []Option {
var opt []Option var opt []Option

Loading…
Cancel
Save