simplify kcp interface

pull/1549/head
Darien Raymond 7 years ago
parent e8e7921613
commit bf7b8798a9
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169

@ -168,13 +168,15 @@ type SystemConnection interface {
Overhead() int
}
var (
_ buf.Reader = (*Connection)(nil)
)
type ConnMetadata struct {
LocalAddr net.Addr
RemoteAddr net.Addr
}
// Connection is a KCP connection over UDP.
type Connection struct {
conn SystemConnection
meta *ConnMetadata
closer io.Closer
rd time.Time
wd time.Time // write deadline
since int64
@ -201,24 +203,24 @@ type Connection struct {
}
// NewConnection create a new KCP connection between local and remote.
func NewConnection(conv uint16, sysConn SystemConnection, config *Config) *Connection {
func NewConnection(conv uint16, meta *ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection {
log.Trace(newError("creating connection ", conv))
conn := &Connection{
conv: conv,
conn: sysConn,
meta: meta,
closer: closer,
since: nowMillisec(),
dataInput: make(chan bool, 1),
dataOutput: make(chan bool, 1),
Config: config,
output: NewRetryableWriter(NewSegmentWriter(sysConn)),
mss: config.GetMTUValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead,
output: NewRetryableWriter(NewSegmentWriter(writer)),
mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead,
roundTrip: &RoundTripInfo{
rto: 100,
minRtt: config.GetTTIValue(),
},
}
sysConn.Reset(conn.Input)
conn.receivingWorker = NewReceivingWorker(conn)
conn.sendingWorker = NewSendingWorker(conn)
@ -413,7 +415,7 @@ func (v *Connection) Close() error {
if state.Is(StateReadyToClose, StateTerminating, StateTerminated) {
return ErrClosedConnection
}
log.Trace(newError("closing connection to ", v.conn.RemoteAddr()))
log.Trace(newError("closing connection to ", v.meta.RemoteAddr))
if state == StateActive {
v.SetState(StateReadyToClose)
@ -433,7 +435,7 @@ func (v *Connection) LocalAddr() net.Addr {
if v == nil {
return nil
}
return v.conn.LocalAddr()
return v.meta.LocalAddr
}
// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
@ -441,7 +443,7 @@ func (v *Connection) RemoteAddr() net.Addr {
if v == nil {
return nil
}
return v.conn.RemoteAddr()
return v.meta.RemoteAddr
}
// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
@ -488,7 +490,7 @@ func (v *Connection) Terminate() {
v.OnDataInput()
v.OnDataOutput()
v.conn.Close()
v.closer.Close()
v.sendingWorker.Release()
v.receivingWorker.Release()
}

@ -1,59 +1,27 @@
package kcp_test
import (
"net"
"io"
"testing"
"time"
"v2ray.com/core/common/buf"
. "v2ray.com/core/transport/internet/kcp"
. "v2ray.com/ext/assert"
)
type NoOpConn struct{}
type NoOpCloser int
func (o *NoOpConn) Overhead() int {
return 0
}
// Write implements io.Writer.
func (o *NoOpConn) Write(b []byte) (int, error) {
return len(b), nil
}
func (o *NoOpConn) Close() error {
return nil
}
func (o *NoOpConn) Read([]byte) (int, error) {
panic("Should not be called.")
}
func (o *NoOpConn) LocalAddr() net.Addr {
return nil
}
func (o *NoOpConn) RemoteAddr() net.Addr {
return nil
}
func (o *NoOpConn) SetDeadline(time.Time) error {
func (NoOpCloser) Close() error {
return nil
}
func (o *NoOpConn) SetReadDeadline(time.Time) error {
return nil
}
func (o *NoOpConn) SetWriteDeadline(time.Time) error {
return nil
}
func (o *NoOpConn) Reset(input func([]Segment)) {}
func TestConnectionReadTimeout(t *testing.T) {
assert := With(t)
conn := NewConnection(1, &NoOpConn{}, &Config{})
conn := NewConnection(1, &ConnMetadata{}, &KCPPacketWriter{
Writer: buf.DiscardBytes,
}, NoOpCloser(0), &Config{})
conn.SetReadDeadline(time.Now().Add(time.Second))
b := make([]byte, 1024)
@ -63,3 +31,11 @@ func TestConnectionReadTimeout(t *testing.T) {
conn.Terminate()
}
func TestConnectionInterface(t *testing.T) {
assert := With(t)
assert((*Connection)(nil), Implements, (*io.Writer)(nil))
assert((*Connection)(nil), Implements, (*io.Reader)(nil))
assert((*Connection)(nil), Implements, (*buf.Reader)(nil))
}

@ -2,9 +2,8 @@ package kcp
import (
"context"
"crypto/cipher"
"crypto/tls"
"sync"
"io"
"sync/atomic"
"v2ray.com/core/app/log"
@ -20,84 +19,20 @@ var (
globalConv = uint32(dice.RollUint16())
)
type ClientConnection struct {
sync.RWMutex
net.Conn
input func([]Segment)
reader PacketReader
writer PacketWriter
}
func (c *ClientConnection) Overhead() int {
c.RLock()
defer c.RUnlock()
if c.writer == nil {
return 0
}
return c.writer.Overhead()
}
// Write implements io.Writer.
func (c *ClientConnection) Write(b []byte) (int, error) {
c.RLock()
defer c.RUnlock()
if c.writer == nil {
return len(b), nil
}
return c.writer.Write(b)
}
func (*ClientConnection) Read([]byte) (int, error) {
panic("KCP|ClientConnection: Read should not be called.")
}
func (c *ClientConnection) Close() error {
return c.Conn.Close()
}
func (c *ClientConnection) Reset(inputCallback func([]Segment)) {
c.Lock()
c.input = inputCallback
c.Unlock()
}
func (c *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) {
c.Lock()
if c.reader == nil {
c.reader = new(KCPPacketReader)
}
c.reader.(*KCPPacketReader).Header = header
c.reader.(*KCPPacketReader).Security = security
if c.writer == nil {
c.writer = new(KCPPacketWriter)
}
c.writer.(*KCPPacketWriter).Header = header
c.writer.(*KCPPacketWriter).Security = security
c.writer.(*KCPPacketWriter).Writer = c.Conn
c.Unlock()
}
func (c *ClientConnection) Run() {
func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn *Connection) {
payload := buf.New()
defer payload.Release()
for {
err := payload.Reset(buf.ReadFrom(c.Conn))
err := payload.Reset(buf.ReadFrom(input))
if err != nil {
payload.Release()
return
}
c.RLock()
if c.input != nil {
segments := c.reader.Read(payload.Bytes())
if len(segments) > 0 {
c.input(segments)
}
segments := reader.Read(payload.Bytes())
if len(segments) > 0 {
conn.Input(segments)
}
c.RUnlock()
}
}
@ -110,10 +45,6 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
if err != nil {
return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err)
}
conn := &ClientConnection{
Conn: rawConn,
}
go conn.Run()
kcpSettings := internet.TransportSettingsFromContext(ctx).(*Config)
@ -125,9 +56,23 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
if err != nil {
return nil, newError("failed to create security").Base(err)
}
conn.ResetSecurity(header, security)
reader := &KCPPacketReader{
Header: header,
Security: security,
}
writer := &KCPPacketWriter{
Header: header,
Security: security,
Writer: rawConn,
}
conv := uint16(atomic.AddUint32(&globalConv, 1))
session := NewConnection(conv, conn, kcpSettings)
session := NewConnection(conv, &ConnMetadata{
LocalAddr: rawConn.LocalAddr(),
RemoteAddr: rawConn.RemoteAddr(),
}, writer, rawConn, kcpSettings)
go fetchInput(ctx, rawConn, reader, session)
var iConn internet.Connection = session

@ -4,9 +4,7 @@ import (
"context"
"crypto/cipher"
"crypto/tls"
"io"
"sync"
"time"
"v2ray.com/core/app/log"
"v2ray.com/core/common"
@ -23,52 +21,6 @@ type ConnectionID struct {
Conv uint16
}
type ServerConnection struct {
local net.Addr
remote net.Addr
writer PacketWriter
closer io.Closer
}
func (c *ServerConnection) Overhead() int {
return c.writer.Overhead()
}
func (*ServerConnection) Read([]byte) (int, error) {
panic("KCP|ServerConnection: Read should not be called.")
}
func (c *ServerConnection) Write(b []byte) (int, error) {
return c.writer.Write(b)
}
func (c *ServerConnection) Close() error {
return c.closer.Close()
}
func (*ServerConnection) Reset(input func([]Segment)) {
}
func (c *ServerConnection) LocalAddr() net.Addr {
return c.local
}
func (c *ServerConnection) RemoteAddr() net.Addr {
return c.remote
}
func (*ServerConnection) SetDeadline(time.Time) error {
return nil
}
func (*ServerConnection) SetReadDeadline(time.Time) error {
return nil
}
func (*ServerConnection) SetWriteDeadline(time.Time) error {
return nil
}
// Listener defines a server listening for connections
type Listener struct {
sync.Mutex
@ -172,17 +124,14 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD
Port: int(src.Port),
}
localAddr := v.hub.Addr()
sConn := &ServerConnection{
local: localAddr,
remote: remoteAddr,
writer: &KCPPacketWriter{
Header: v.header,
Writer: writer,
Security: v.security,
},
closer: writer,
}
conn = NewConnection(conv, sConn, v.config)
conn = NewConnection(conv, &ConnMetadata{
LocalAddr: localAddr,
RemoteAddr: remoteAddr,
}, &KCPPacketWriter{
Header: v.header,
Security: v.security,
Writer: writer,
}, writer, v.config)
var netConn internet.Connection = conn
if v.tlsConfig != nil {
tlsConn := tls.Server(conn, v.tlsConfig)

Loading…
Cancel
Save