simplify kcp interface

pull/1549/head
Darien Raymond 7 years ago
parent e8e7921613
commit bf7b8798a9
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169

@ -168,13 +168,15 @@ type SystemConnection interface {
Overhead() int Overhead() int
} }
var ( type ConnMetadata struct {
_ buf.Reader = (*Connection)(nil) LocalAddr net.Addr
) RemoteAddr net.Addr
}
// Connection is a KCP connection over UDP. // Connection is a KCP connection over UDP.
type Connection struct { type Connection struct {
conn SystemConnection meta *ConnMetadata
closer io.Closer
rd time.Time rd time.Time
wd time.Time // write deadline wd time.Time // write deadline
since int64 since int64
@ -201,24 +203,24 @@ type Connection struct {
} }
// NewConnection create a new KCP connection between local and remote. // NewConnection create a new KCP connection between local and remote.
func NewConnection(conv uint16, sysConn SystemConnection, config *Config) *Connection { func NewConnection(conv uint16, meta *ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection {
log.Trace(newError("creating connection ", conv)) log.Trace(newError("creating connection ", conv))
conn := &Connection{ conn := &Connection{
conv: conv, conv: conv,
conn: sysConn, meta: meta,
closer: closer,
since: nowMillisec(), since: nowMillisec(),
dataInput: make(chan bool, 1), dataInput: make(chan bool, 1),
dataOutput: make(chan bool, 1), dataOutput: make(chan bool, 1),
Config: config, Config: config,
output: NewRetryableWriter(NewSegmentWriter(sysConn)), output: NewRetryableWriter(NewSegmentWriter(writer)),
mss: config.GetMTUValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead, mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead,
roundTrip: &RoundTripInfo{ roundTrip: &RoundTripInfo{
rto: 100, rto: 100,
minRtt: config.GetTTIValue(), minRtt: config.GetTTIValue(),
}, },
} }
sysConn.Reset(conn.Input)
conn.receivingWorker = NewReceivingWorker(conn) conn.receivingWorker = NewReceivingWorker(conn)
conn.sendingWorker = NewSendingWorker(conn) conn.sendingWorker = NewSendingWorker(conn)
@ -413,7 +415,7 @@ func (v *Connection) Close() error {
if state.Is(StateReadyToClose, StateTerminating, StateTerminated) { if state.Is(StateReadyToClose, StateTerminating, StateTerminated) {
return ErrClosedConnection return ErrClosedConnection
} }
log.Trace(newError("closing connection to ", v.conn.RemoteAddr())) log.Trace(newError("closing connection to ", v.meta.RemoteAddr))
if state == StateActive { if state == StateActive {
v.SetState(StateReadyToClose) v.SetState(StateReadyToClose)
@ -433,7 +435,7 @@ func (v *Connection) LocalAddr() net.Addr {
if v == nil { if v == nil {
return nil return nil
} }
return v.conn.LocalAddr() return v.meta.LocalAddr
} }
// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
@ -441,7 +443,7 @@ func (v *Connection) RemoteAddr() net.Addr {
if v == nil { if v == nil {
return nil return nil
} }
return v.conn.RemoteAddr() return v.meta.RemoteAddr
} }
// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
@ -488,7 +490,7 @@ func (v *Connection) Terminate() {
v.OnDataInput() v.OnDataInput()
v.OnDataOutput() v.OnDataOutput()
v.conn.Close() v.closer.Close()
v.sendingWorker.Release() v.sendingWorker.Release()
v.receivingWorker.Release() v.receivingWorker.Release()
} }

@ -1,59 +1,27 @@
package kcp_test package kcp_test
import ( import (
"net" "io"
"testing" "testing"
"time" "time"
"v2ray.com/core/common/buf"
. "v2ray.com/core/transport/internet/kcp" . "v2ray.com/core/transport/internet/kcp"
. "v2ray.com/ext/assert" . "v2ray.com/ext/assert"
) )
type NoOpConn struct{} type NoOpCloser int
func (o *NoOpConn) Overhead() int { func (NoOpCloser) Close() error {
return 0
}
// Write implements io.Writer.
func (o *NoOpConn) Write(b []byte) (int, error) {
return len(b), nil
}
func (o *NoOpConn) Close() error {
return nil
}
func (o *NoOpConn) Read([]byte) (int, error) {
panic("Should not be called.")
}
func (o *NoOpConn) LocalAddr() net.Addr {
return nil
}
func (o *NoOpConn) RemoteAddr() net.Addr {
return nil
}
func (o *NoOpConn) SetDeadline(time.Time) error {
return nil return nil
} }
func (o *NoOpConn) SetReadDeadline(time.Time) error {
return nil
}
func (o *NoOpConn) SetWriteDeadline(time.Time) error {
return nil
}
func (o *NoOpConn) Reset(input func([]Segment)) {}
func TestConnectionReadTimeout(t *testing.T) { func TestConnectionReadTimeout(t *testing.T) {
assert := With(t) assert := With(t)
conn := NewConnection(1, &NoOpConn{}, &Config{}) conn := NewConnection(1, &ConnMetadata{}, &KCPPacketWriter{
Writer: buf.DiscardBytes,
}, NoOpCloser(0), &Config{})
conn.SetReadDeadline(time.Now().Add(time.Second)) conn.SetReadDeadline(time.Now().Add(time.Second))
b := make([]byte, 1024) b := make([]byte, 1024)
@ -63,3 +31,11 @@ func TestConnectionReadTimeout(t *testing.T) {
conn.Terminate() conn.Terminate()
} }
func TestConnectionInterface(t *testing.T) {
assert := With(t)
assert((*Connection)(nil), Implements, (*io.Writer)(nil))
assert((*Connection)(nil), Implements, (*io.Reader)(nil))
assert((*Connection)(nil), Implements, (*buf.Reader)(nil))
}

@ -2,9 +2,8 @@ package kcp
import ( import (
"context" "context"
"crypto/cipher"
"crypto/tls" "crypto/tls"
"sync" "io"
"sync/atomic" "sync/atomic"
"v2ray.com/core/app/log" "v2ray.com/core/app/log"
@ -20,84 +19,20 @@ var (
globalConv = uint32(dice.RollUint16()) globalConv = uint32(dice.RollUint16())
) )
type ClientConnection struct { func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn *Connection) {
sync.RWMutex
net.Conn
input func([]Segment)
reader PacketReader
writer PacketWriter
}
func (c *ClientConnection) Overhead() int {
c.RLock()
defer c.RUnlock()
if c.writer == nil {
return 0
}
return c.writer.Overhead()
}
// Write implements io.Writer.
func (c *ClientConnection) Write(b []byte) (int, error) {
c.RLock()
defer c.RUnlock()
if c.writer == nil {
return len(b), nil
}
return c.writer.Write(b)
}
func (*ClientConnection) Read([]byte) (int, error) {
panic("KCP|ClientConnection: Read should not be called.")
}
func (c *ClientConnection) Close() error {
return c.Conn.Close()
}
func (c *ClientConnection) Reset(inputCallback func([]Segment)) {
c.Lock()
c.input = inputCallback
c.Unlock()
}
func (c *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) {
c.Lock()
if c.reader == nil {
c.reader = new(KCPPacketReader)
}
c.reader.(*KCPPacketReader).Header = header
c.reader.(*KCPPacketReader).Security = security
if c.writer == nil {
c.writer = new(KCPPacketWriter)
}
c.writer.(*KCPPacketWriter).Header = header
c.writer.(*KCPPacketWriter).Security = security
c.writer.(*KCPPacketWriter).Writer = c.Conn
c.Unlock()
}
func (c *ClientConnection) Run() {
payload := buf.New() payload := buf.New()
defer payload.Release() defer payload.Release()
for { for {
err := payload.Reset(buf.ReadFrom(c.Conn)) err := payload.Reset(buf.ReadFrom(input))
if err != nil { if err != nil {
payload.Release() payload.Release()
return return
} }
c.RLock() segments := reader.Read(payload.Bytes())
if c.input != nil { if len(segments) > 0 {
segments := c.reader.Read(payload.Bytes()) conn.Input(segments)
if len(segments) > 0 {
c.input(segments)
}
} }
c.RUnlock()
} }
} }
@ -110,10 +45,6 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
if err != nil { if err != nil {
return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err) return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err)
} }
conn := &ClientConnection{
Conn: rawConn,
}
go conn.Run()
kcpSettings := internet.TransportSettingsFromContext(ctx).(*Config) kcpSettings := internet.TransportSettingsFromContext(ctx).(*Config)
@ -125,9 +56,23 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
if err != nil { if err != nil {
return nil, newError("failed to create security").Base(err) return nil, newError("failed to create security").Base(err)
} }
conn.ResetSecurity(header, security) reader := &KCPPacketReader{
Header: header,
Security: security,
}
writer := &KCPPacketWriter{
Header: header,
Security: security,
Writer: rawConn,
}
conv := uint16(atomic.AddUint32(&globalConv, 1)) conv := uint16(atomic.AddUint32(&globalConv, 1))
session := NewConnection(conv, conn, kcpSettings) session := NewConnection(conv, &ConnMetadata{
LocalAddr: rawConn.LocalAddr(),
RemoteAddr: rawConn.RemoteAddr(),
}, writer, rawConn, kcpSettings)
go fetchInput(ctx, rawConn, reader, session)
var iConn internet.Connection = session var iConn internet.Connection = session

@ -4,9 +4,7 @@ import (
"context" "context"
"crypto/cipher" "crypto/cipher"
"crypto/tls" "crypto/tls"
"io"
"sync" "sync"
"time"
"v2ray.com/core/app/log" "v2ray.com/core/app/log"
"v2ray.com/core/common" "v2ray.com/core/common"
@ -23,52 +21,6 @@ type ConnectionID struct {
Conv uint16 Conv uint16
} }
type ServerConnection struct {
local net.Addr
remote net.Addr
writer PacketWriter
closer io.Closer
}
func (c *ServerConnection) Overhead() int {
return c.writer.Overhead()
}
func (*ServerConnection) Read([]byte) (int, error) {
panic("KCP|ServerConnection: Read should not be called.")
}
func (c *ServerConnection) Write(b []byte) (int, error) {
return c.writer.Write(b)
}
func (c *ServerConnection) Close() error {
return c.closer.Close()
}
func (*ServerConnection) Reset(input func([]Segment)) {
}
func (c *ServerConnection) LocalAddr() net.Addr {
return c.local
}
func (c *ServerConnection) RemoteAddr() net.Addr {
return c.remote
}
func (*ServerConnection) SetDeadline(time.Time) error {
return nil
}
func (*ServerConnection) SetReadDeadline(time.Time) error {
return nil
}
func (*ServerConnection) SetWriteDeadline(time.Time) error {
return nil
}
// Listener defines a server listening for connections // Listener defines a server listening for connections
type Listener struct { type Listener struct {
sync.Mutex sync.Mutex
@ -172,17 +124,14 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD
Port: int(src.Port), Port: int(src.Port),
} }
localAddr := v.hub.Addr() localAddr := v.hub.Addr()
sConn := &ServerConnection{ conn = NewConnection(conv, &ConnMetadata{
local: localAddr, LocalAddr: localAddr,
remote: remoteAddr, RemoteAddr: remoteAddr,
writer: &KCPPacketWriter{ }, &KCPPacketWriter{
Header: v.header, Header: v.header,
Writer: writer, Security: v.security,
Security: v.security, Writer: writer,
}, }, writer, v.config)
closer: writer,
}
conn = NewConnection(conv, sConn, v.config)
var netConn internet.Connection = conn var netConn internet.Connection = conn
if v.tlsConfig != nil { if v.tlsConfig != nil {
tlsConn := tls.Server(conn, v.tlsConfig) tlsConn := tls.Server(conn, v.tlsConfig)

Loading…
Cancel
Save