refine kcp header and security

pull/314/head
Darien Raymond 2016-12-08 16:27:41 +01:00
parent 0ad629ca31
commit 0917866f38
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
39 changed files with 530 additions and 393 deletions

8
all.go
View File

@ -21,8 +21,8 @@ import (
_ "v2ray.com/core/transport/internet/udp" _ "v2ray.com/core/transport/internet/udp"
_ "v2ray.com/core/transport/internet/ws" _ "v2ray.com/core/transport/internet/ws"
_ "v2ray.com/core/transport/internet/authenticators/http" _ "v2ray.com/core/transport/internet/headers/http"
_ "v2ray.com/core/transport/internet/authenticators/noop" _ "v2ray.com/core/transport/internet/headers/noop"
_ "v2ray.com/core/transport/internet/authenticators/srtp" _ "v2ray.com/core/transport/internet/headers/srtp"
_ "v2ray.com/core/transport/internet/authenticators/utp" _ "v2ray.com/core/transport/internet/headers/utp"
) )

View File

@ -3,10 +3,10 @@ package conf
import ( import (
"v2ray.com/core/common/errors" "v2ray.com/core/common/errors"
"v2ray.com/core/common/loader" "v2ray.com/core/common/loader"
"v2ray.com/core/transport/internet/authenticators/http" "v2ray.com/core/transport/internet/headers/http"
"v2ray.com/core/transport/internet/authenticators/noop" "v2ray.com/core/transport/internet/headers/noop"
"v2ray.com/core/transport/internet/authenticators/srtp" "v2ray.com/core/transport/internet/headers/srtp"
"v2ray.com/core/transport/internet/authenticators/utp" "v2ray.com/core/transport/internet/headers/utp"
) )
type NoOpAuthenticator struct{} type NoOpAuthenticator struct{}

View File

@ -1,28 +0,0 @@
package internet_test
import (
"testing"
"v2ray.com/core/common/loader"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/authenticators/noop"
"v2ray.com/core/transport/internet/authenticators/srtp"
"v2ray.com/core/transport/internet/authenticators/utp"
)
func TestAllAuthenticatorLoadable(t *testing.T) {
assert := assert.On(t)
noopAuth, err := CreateAuthenticator(loader.GetType(new(noop.Config)), nil)
assert.Error(err).IsNil()
assert.Int(noopAuth.Overhead()).Equals(0)
srtp, err := CreateAuthenticator(loader.GetType(new(srtp.Config)), nil)
assert.Error(err).IsNil()
assert.Int(srtp.Overhead()).Equals(4)
utp, err := CreateAuthenticator(loader.GetType(new(utp.Config)), nil)
assert.Error(err).IsNil()
assert.Int(utp.Overhead()).Equals(4)
}

View File

@ -1,8 +0,0 @@
syntax = "proto3";
package v2ray.core.transport.internet.authenticators.noop;
option go_package = "noop";
option java_package = "com.v2ray.core.transport.internet.authenticators.noop";
option java_outer_classname = "ConfigProto";
message Config {}

View File

@ -1,46 +0,0 @@
package noop
import (
"net"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/loader"
"v2ray.com/core/transport/internet"
)
type NoOpAuthenticator struct{}
func (v NoOpAuthenticator) Overhead() int {
return 0
}
func (v NoOpAuthenticator) Open(payload *alloc.Buffer) bool {
return true
}
func (v NoOpAuthenticator) Seal(payload *alloc.Buffer) {}
type NoOpAuthenticatorFactory struct{}
func (v NoOpAuthenticatorFactory) Create(config interface{}) internet.Authenticator {
return NoOpAuthenticator{}
}
type NoOpConnectionAuthenticator struct{}
func (NoOpConnectionAuthenticator) Client(conn net.Conn) net.Conn {
return conn
}
func (NoOpConnectionAuthenticator) Server(conn net.Conn) net.Conn {
return conn
}
type NoOpConnectionAuthenticatorFactory struct{}
func (NoOpConnectionAuthenticatorFactory) Create(config interface{}) internet.ConnectionAuthenticator {
return NoOpConnectionAuthenticator{}
}
func init() {
internet.RegisterAuthenticator(loader.GetType(new(Config)), NoOpAuthenticatorFactory{})
internet.RegisterConnectionAuthenticator(loader.GetType(new(Config)), NoOpConnectionAuthenticatorFactory{})
}

View File

@ -1,44 +0,0 @@
package srtp
import (
"math/rand"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/loader"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
type SRTP struct {
header uint16
number uint16
}
func (v *SRTP) Overhead() int {
return 4
}
func (v *SRTP) Open(payload *alloc.Buffer) bool {
payload.SliceFrom(v.Overhead())
return true
}
func (v *SRTP) Seal(payload *alloc.Buffer) {
v.number++
payload.PrependFunc(2, serial.WriteUint16(v.number))
payload.PrependFunc(2, serial.WriteUint16(v.header))
}
type SRTPFactory struct {
}
func (v SRTPFactory) Create(rawSettings interface{}) internet.Authenticator {
return &SRTP{
header: 0xB5E8,
number: uint16(rand.Intn(65536)),
}
}
func init() {
internet.RegisterAuthenticator(loader.GetType(new(Config)), SRTPFactory{})
}

View File

@ -1,10 +0,0 @@
syntax = "proto3";
package v2ray.core.transport.internet.authenticators.utp;
option go_package = "utp";
option java_package = "com.v2ray.core.transport.internet.authenticators.utp";
option java_outer_classname = "ConfigProto";
message Config {
uint32 version = 1;
}

View File

@ -1,44 +0,0 @@
package utp
import (
"math/rand"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/loader"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
type UTP struct {
header byte
extension byte
connectionId uint16
}
func (v *UTP) Overhead() int {
return 4
}
func (v *UTP) Open(payload *alloc.Buffer) bool {
payload.SliceFrom(v.Overhead())
return true
}
func (v *UTP) Seal(payload *alloc.Buffer) {
payload.PrependFunc(2, serial.WriteUint16(v.connectionId))
payload.PrependBytes(v.header, v.extension)
}
type UTPFactory struct{}
func (v UTPFactory) Create(rawSettings interface{}) internet.Authenticator {
return &UTP{
header: 1,
extension: 0,
connectionId: uint16(rand.Intn(65536)),
}
}
func init() {
internet.RegisterAuthenticator(loader.GetType(new(Config)), UTPFactory{})
}

View File

@ -9,20 +9,6 @@ import (
. "v2ray.com/core/transport/internet" . "v2ray.com/core/transport/internet"
) )
func TestDialDomain(t *testing.T) {
assert := assert.On(t)
server := &tcp.Server{}
dest, err := server.Start()
assert.Error(err).IsNil()
defer server.Close()
conn, err := DialToDest(nil, v2net.TCPDestination(v2net.DomainAddress("local.v2ray.com"), dest.Port))
assert.Error(err).IsNil()
assert.String(conn.RemoteAddr().String()).Equals("127.0.0.1:" + dest.Port.String())
conn.Close()
}
func TestDialWithLocalAddr(t *testing.T) { func TestDialWithLocalAddr(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)

View File

@ -0,0 +1,32 @@
package internet
import "v2ray.com/core/common"
type PacketHeader interface {
Size() int
Write([]byte) int
}
type PacketHeaderFactory interface {
Create(interface{}) PacketHeader
}
var (
headerCache = make(map[string]PacketHeaderFactory)
)
func RegisterPacketHeader(name string, factory PacketHeaderFactory) error {
if _, found := headerCache[name]; found {
return common.ErrDuplicatedName
}
headerCache[name] = factory
return nil
}
func CreatePacketHeader(name string, config interface{}) (PacketHeader, error) {
factory, found := headerCache[name]
if !found {
return nil, common.ErrObjectNotFound
}
return factory.Create(config), nil
}

View File

@ -0,0 +1,28 @@
package internet_test
import (
"testing"
"v2ray.com/core/common/loader"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/headers/noop"
"v2ray.com/core/transport/internet/headers/srtp"
"v2ray.com/core/transport/internet/headers/utp"
)
func TestAllHeadersLoadable(t *testing.T) {
assert := assert.On(t)
noopAuth, err := CreatePacketHeader(loader.GetType(new(noop.Config)), nil)
assert.Error(err).IsNil()
assert.Int(noopAuth.Size()).Equals(0)
srtp, err := CreatePacketHeader(loader.GetType(new(srtp.Config)), nil)
assert.Error(err).IsNil()
assert.Int(srtp.Size()).Equals(4)
utp, err := CreatePacketHeader(loader.GetType(new(utp.Config)), nil)
assert.Error(err).IsNil()
assert.Int(utp.Size()).Equals(4)
}

View File

@ -1,8 +1,8 @@
syntax = "proto3"; syntax = "proto3";
package v2ray.core.transport.internet.authenticators.http; package v2ray.core.transport.internet.headers.http;
option go_package = "http"; option go_package = "http";
option java_package = "com.v2ray.core.transport.internet.authenticators.http"; option java_package = "com.v2ray.core.transport.internet.headers.http";
option java_outer_classname = "ConfigProto"; option java_outer_classname = "ConfigProto";
message Header { message Header {

View File

@ -6,7 +6,7 @@ import (
"v2ray.com/core/common/alloc" "v2ray.com/core/common/alloc"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet/authenticators/http" . "v2ray.com/core/transport/internet/headers/http"
) )
func TestReaderWriter(t *testing.T) { func TestReaderWriter(t *testing.T) {

View File

@ -0,0 +1,8 @@
syntax = "proto3";
package v2ray.core.transport.internet.headers.noop;
option go_package = "noop";
option java_package = "com.v2ray.core.transport.internet.headers.noop";
option java_outer_classname = "ConfigProto";
message Config {}

View File

@ -0,0 +1,44 @@
package noop
import (
"net"
"v2ray.com/core/common/loader"
"v2ray.com/core/transport/internet"
)
type NoOpHeader struct{}
func (v NoOpHeader) Size() int {
return 0
}
func (v NoOpHeader) Write([]byte) int {
return 0
}
type NoOpHeaderFactory struct{}
func (v NoOpHeaderFactory) Create(config interface{}) internet.PacketHeader {
return NoOpHeader{}
}
type NoOpConnectionHeader struct{}
func (NoOpConnectionHeader) Client(conn net.Conn) net.Conn {
return conn
}
func (NoOpConnectionHeader) Server(conn net.Conn) net.Conn {
return conn
}
type NoOpConnectionHeaderFactory struct{}
func (NoOpConnectionHeaderFactory) Create(config interface{}) internet.ConnectionAuthenticator {
return NoOpConnectionHeader{}
}
func init() {
internet.RegisterPacketHeader(loader.GetType(new(Config)), NoOpHeaderFactory{})
internet.RegisterConnectionAuthenticator(loader.GetType(new(Config)), NoOpConnectionHeaderFactory{})
}

View File

@ -1,8 +1,8 @@
syntax = "proto3"; syntax = "proto3";
package v2ray.core.transport.internet.authenticators.srtp; package v2ray.core.transport.internet.headers.srtp;
option go_package = "srtp"; option go_package = "srtp";
option java_package = "com.v2ray.core.transport.internet.authenticators.srtp"; option java_package = "com.v2ray.core.transport.internet.headers.srtp";
option java_outer_classname = "ConfigProto"; option java_outer_classname = "ConfigProto";
message Config { message Config {

View File

@ -0,0 +1,39 @@
package srtp
import (
"math/rand"
"v2ray.com/core/common/loader"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
type SRTP struct {
header uint16
number uint16
}
func (v *SRTP) Size() int {
return 4
}
func (v *SRTP) Write(b []byte) int {
v.number++
b = serial.Uint16ToBytes(v.number, b[:0])
b = serial.Uint16ToBytes(v.number, b)
return 4
}
type SRTPFactory struct {
}
func (v SRTPFactory) Create(rawSettings interface{}) internet.PacketHeader {
return &SRTP{
header: 0xB5E8,
number: uint16(rand.Intn(65536)),
}
}
func init() {
internet.RegisterPacketHeader(loader.GetType(new(Config)), SRTPFactory{})
}

View File

@ -5,19 +5,18 @@ import (
"v2ray.com/core/common/alloc" "v2ray.com/core/common/alloc"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet/authenticators/srtp" . "v2ray.com/core/transport/internet/headers/srtp"
) )
func TestSRTPOpenSeal(t *testing.T) { func TestSRTPWrite(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'} content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'}
srtp := SRTP{}
payload := alloc.NewLocalBuffer(2048) payload := alloc.NewLocalBuffer(2048)
payload.AppendFunc(srtp.Write)
payload.Append(content) payload.Append(content)
srtp := SRTP{} assert.Int(payload.Len()).Equals(len(content) + srtp.Size())
srtp.Seal(payload)
assert.Int(payload.Len()).GreaterThan(len(content))
assert.Bool(srtp.Open(payload)).IsTrue()
assert.Bytes(content).Equals(payload.Bytes())
} }

View File

@ -0,0 +1,10 @@
syntax = "proto3";
package v2ray.core.transport.internet.headers.utp;
option go_package = "utp";
option java_package = "com.v2ray.core.transport.internet.headers.utp";
option java_outer_classname = "ConfigProto";
message Config {
uint32 version = 1;
}

View File

@ -0,0 +1,39 @@
package utp
import (
"math/rand"
"v2ray.com/core/common/loader"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
type UTP struct {
header byte
extension byte
connectionId uint16
}
func (v *UTP) Size() int {
return 4
}
func (v *UTP) Write(b []byte) int {
b = serial.Uint16ToBytes(v.connectionId, b[:0])
b = append(b, v.header, v.extension)
return 4
}
type UTPFactory struct{}
func (v UTPFactory) Create(rawSettings interface{}) internet.PacketHeader {
return &UTP{
header: 1,
extension: 0,
connectionId: uint16(rand.Intn(65536)),
}
}
func init() {
internet.RegisterPacketHeader(loader.GetType(new(Config)), UTPFactory{})
}

View File

@ -5,19 +5,18 @@ import (
"v2ray.com/core/common/alloc" "v2ray.com/core/common/alloc"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet/authenticators/utp" . "v2ray.com/core/transport/internet/headers/utp"
) )
func TestUTPOpenSeal(t *testing.T) { func TestUTPWrite(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'} content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'}
utp := UTP{}
payload := alloc.NewLocalBuffer(2048) payload := alloc.NewLocalBuffer(2048)
payload.AppendFunc(utp.Write)
payload.Append(content) payload.Append(content)
utp := UTP{} assert.Int(payload.Len()).Equals(len(content) + utp.Size())
utp.Seal(payload)
assert.Int(payload.Len()).GreaterThan(len(content))
assert.Bool(utp.Open(payload)).IsTrue()
assert.Bytes(content).Equals(payload.Bytes())
} }

View File

@ -1,6 +1,8 @@
package kcp package kcp
import ( import (
"crypto/cipher"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
) )
@ -47,21 +49,20 @@ func (v *ReadBuffer) GetSize() uint32 {
return v.Size return v.Size
} }
func (v *Config) GetAuthenticator() (internet.Authenticator, error) { func (v *Config) GetSecurity() (cipher.AEAD, error) {
auth := NewSimpleAuthenticator() return NewSimpleAuthenticator(), nil
}
func (v *Config) GetPackerHeader() (internet.PacketHeader, error) {
if v.HeaderConfig != nil { if v.HeaderConfig != nil {
rawConfig, err := v.HeaderConfig.GetInstance() rawConfig, err := v.HeaderConfig.GetInstance()
if err != nil { if err != nil {
return nil, err return nil, err
} }
header, err := internet.CreateAuthenticator(v.HeaderConfig.Type, rawConfig) return internet.CreatePacketHeader(v.HeaderConfig.Type, rawConfig)
if err != nil {
return nil, err
}
auth = internet.NewAuthenticatorChain(header, auth)
} }
return auth, nil return nil, nil
} }
func (v *Config) GetSendingInFlightSize() uint32 { func (v *Config) GetSendingInFlightSize() uint32 {

View File

@ -6,6 +6,7 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"v2ray.com/core/common/errors" "v2ray.com/core/common/errors"
"v2ray.com/core/common/log" "v2ray.com/core/common/log"
"v2ray.com/core/common/predicate" "v2ray.com/core/common/predicate"
@ -161,7 +162,8 @@ func (v *Updater) Run() {
type SystemConnection interface { type SystemConnection interface {
net.Conn net.Conn
Id() internal.ConnectionId Id() internal.ConnectionId
Reset(internet.Authenticator, func([]byte)) Reset(func([]Segment))
Overhead() int
} }
// Connection is a KCP connection over UDP. // Connection is a KCP connection over UDP.
@ -197,32 +199,25 @@ type Connection struct {
} }
// NewConnection create a new KCP connection between local and remote. // NewConnection create a new KCP connection between local and remote.
func NewConnection(conv uint16, sysConn SystemConnection, recycler internal.ConnectionRecyler, block internet.Authenticator, config *Config) *Connection { func NewConnection(conv uint16, sysConn SystemConnection, recycler internal.ConnectionRecyler, config *Config) *Connection {
log.Info("KCP|Connection: creating connection ", conv) log.Info("KCP|Connection: creating connection ", conv)
authWriter := &AuthenticationWriter{
Authenticator: block,
Writer: sysConn,
Config: config,
}
conn := &Connection{ conn := &Connection{
conv: conv, conv: conv,
conn: sysConn, conn: sysConn,
connRecycler: recycler, connRecycler: recycler,
block: block,
since: nowMillisec(), since: nowMillisec(),
dataInputCond: sync.NewCond(new(sync.Mutex)), dataInputCond: sync.NewCond(new(sync.Mutex)),
dataOutputCond: sync.NewCond(new(sync.Mutex)), dataOutputCond: sync.NewCond(new(sync.Mutex)),
Config: config, Config: config,
output: NewSegmentWriter(authWriter), output: NewSegmentWriter(sysConn, config.GetMtu().GetValue()-uint32(sysConn.Overhead())),
mss: authWriter.Mtu() - DataSegmentOverhead, mss: config.GetMtu().GetValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead,
roundTrip: &RoundTripInfo{ roundTrip: &RoundTripInfo{
rto: 100, rto: 100,
minRtt: config.Tti.GetValue(), minRtt: config.Tti.GetValue(),
}, },
} }
sysConn.Reset(block, conn.Input) sysConn.Reset(conn.Input)
conn.receivingWorker = NewReceivingWorker(conn) conn.receivingWorker = NewReceivingWorker(conn)
conn.sendingWorker = NewSendingWorker(conn) conn.sendingWorker = NewSendingWorker(conn)
@ -480,16 +475,11 @@ func (v *Connection) OnPeerClosed() {
} }
// Input when you received a low level packet (eg. UDP packet), call it // Input when you received a low level packet (eg. UDP packet), call it
func (v *Connection) Input(data []byte) { func (v *Connection) Input(segments []Segment) {
current := v.Elapsed() current := v.Elapsed()
atomic.StoreUint32(&v.lastIncomingTime, current) atomic.StoreUint32(&v.lastIncomingTime, current)
var seg Segment for _, seg := range segments {
for {
seg, data = ReadSegment(data)
if seg == nil {
break
}
if seg.Conversation() != v.conv { if seg.Conversation() != v.conv {
return return
} }
@ -507,7 +497,7 @@ func (v *Connection) Input(data []byte) {
v.dataUpdater.WakeUp() v.dataUpdater.WakeUp()
case *CmdOnlySegment: case *CmdOnlySegment:
v.HandleOption(seg.Option) v.HandleOption(seg.Option)
if seg.Command == CommandTerminate { if seg.Command() == CommandTerminate {
state := v.State() state := v.State()
if state == StateActive || if state == StateActive ||
state == StatePeerClosed { state == StatePeerClosed {
@ -577,7 +567,7 @@ func (v *Connection) State() State {
func (v *Connection) Ping(current uint32, cmd Command) { func (v *Connection) Ping(current uint32, cmd Command) {
seg := NewCmdOnlySegment() seg := NewCmdOnlySegment()
seg.Conv = v.conv seg.Conv = v.conv
seg.Command = cmd seg.Cmd = cmd
seg.ReceivinNext = v.receivingWorker.nextNumber seg.ReceivinNext = v.receivingWorker.nextNumber
seg.SendingNext = v.sendingWorker.firstUnacknowledged seg.SendingNext = v.sendingWorker.firstUnacknowledged
seg.PeerRTO = v.roundTrip.Timeout() seg.PeerRTO = v.roundTrip.Timeout()

View File

@ -4,14 +4,18 @@ import (
"net" "net"
"testing" "testing"
"time" "time"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/internal" "v2ray.com/core/transport/internet/internal"
. "v2ray.com/core/transport/internet/kcp" . "v2ray.com/core/transport/internet/kcp"
) )
type NoOpConn struct{} type NoOpConn struct{}
func (o *NoOpConn) Overhead() int {
return 0
}
func (o *NoOpConn) Write(b []byte) (int, error) { func (o *NoOpConn) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
@ -48,7 +52,7 @@ func (o *NoOpConn) Id() internal.ConnectionId {
return internal.ConnectionId{} return internal.ConnectionId{}
} }
func (o *NoOpConn) Reset(auth internet.Authenticator, input func([]byte)) {} func (o *NoOpConn) Reset(input func([]Segment)) {}
type NoOpRecycler struct{} type NoOpRecycler struct{}
@ -57,7 +61,7 @@ func (o *NoOpRecycler) Put(internal.ConnectionId, net.Conn) {}
func TestConnectionReadTimeout(t *testing.T) { func TestConnectionReadTimeout(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
conn := NewConnection(1, &NoOpConn{}, &NoOpRecycler{}, NewSimpleAuthenticator(), &Config{}) conn := NewConnection(1, &NoOpConn{}, &NoOpRecycler{}, &Config{})
conn.SetReadDeadline(time.Now().Add(time.Second)) conn.SetReadDeadline(time.Now().Add(time.Second))
b := make([]byte, 1024) b := make([]byte, 1024)

View File

@ -1,63 +1,74 @@
package kcp package kcp
import ( import (
"crypto/cipher"
"errors"
"hash/fnv" "hash/fnv"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet" )
var (
errInvalidAuth = errors.New("Invalid auth.")
) )
type SimpleAuthenticator struct{} type SimpleAuthenticator struct{}
func NewSimpleAuthenticator() internet.Authenticator { func NewSimpleAuthenticator() cipher.AEAD {
return &SimpleAuthenticator{} return &SimpleAuthenticator{}
} }
func (v *SimpleAuthenticator) NonceSize() int {
return 0
}
func (v *SimpleAuthenticator) Overhead() int { func (v *SimpleAuthenticator) Overhead() int {
return 6 return 6
} }
func (v *SimpleAuthenticator) Seal(buffer *alloc.Buffer) { func (v *SimpleAuthenticator) Seal(dst, nonce, plain, extra []byte) []byte {
buffer.PrependFunc(2, serial.WriteUint16(uint16(buffer.Len()))) dst = append(dst, 0, 0, 0, 0)
fnvHash := fnv.New32a() dst = serial.Uint16ToBytes(uint16(len(plain)), dst)
fnvHash.Write(buffer.Bytes()) dst = append(dst, plain...)
buffer.PrependFunc(4, serial.WriteHash(fnvHash))
len := buffer.Len() fnvHash := fnv.New32a()
fnvHash.Write(dst[4:])
fnvHash.Sum(dst[:0])
len := len(dst)
xtra := 4 - len%4 xtra := 4 - len%4
if xtra != 0 { if xtra != 4 {
buffer.Slice(0, len+xtra) dst = append(dst, make([]byte, xtra)...)
} }
xorfwd(buffer.Bytes()) xorfwd(dst)
if xtra != 0 { if xtra != 4 {
buffer.Slice(0, len) dst = dst[:len]
} }
return dst
} }
func (v *SimpleAuthenticator) Open(buffer *alloc.Buffer) bool { func (v *SimpleAuthenticator) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) {
len := buffer.Len() dst = append(dst, cipherText...)
xtra := 4 - len%4 dstLen := len(dst)
if xtra != 0 { xtra := 4 - dstLen%4
buffer.Slice(0, len+xtra) if xtra != 4 {
dst = append(dst, make([]byte, xtra)...)
} }
xorbkd(buffer.Bytes()) xorbkd(dst)
if xtra != 0 { if xtra != 4 {
buffer.Slice(0, len) dst = dst[:dstLen]
} }
fnvHash := fnv.New32a() fnvHash := fnv.New32a()
fnvHash.Write(buffer.BytesFrom(4)) fnvHash.Write(dst[4:])
if serial.BytesToUint32(buffer.BytesTo(4)) != fnvHash.Sum32() { if serial.BytesToUint32(dst[:4]) != fnvHash.Sum32() {
return false return nil, errInvalidAuth
} }
length := serial.BytesToUint16(buffer.BytesRange(4, 6)) length := serial.BytesToUint16(dst[4:6])
if buffer.Len()-6 != int(length) { if len(dst)-6 != int(length) {
return false return nil, errInvalidAuth
} }
buffer.SliceFrom(6) return dst[6:], nil
return true
} }

View File

@ -1,10 +1,8 @@
package kcp_test package kcp_test
import ( import (
"crypto/rand"
"testing" "testing"
"v2ray.com/core/common/alloc"
"v2ray.com/core/testing/assert" "v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet/kcp" . "v2ray.com/core/transport/internet/kcp"
) )
@ -12,38 +10,27 @@ import (
func TestSimpleAuthenticator(t *testing.T) { func TestSimpleAuthenticator(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
buffer := alloc.NewLocalBuffer(512) cache := make([]byte, 512)
buffer.AppendBytes('a', 'b', 'c', 'd', 'e', 'f', 'g')
payload := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'}
auth := NewSimpleAuthenticator() auth := NewSimpleAuthenticator()
auth.Seal(buffer) b := auth.Seal(cache[:0], nil, payload, nil)
c, err := auth.Open(cache[:0], nil, b, nil)
assert.Bool(auth.Open(buffer)).IsTrue() assert.Error(err).IsNil()
assert.Bytes(buffer.Bytes()).Equals([]byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'}) assert.Bytes(c).Equals(payload)
} }
func TestSimpleAuthenticator2(t *testing.T) { func TestSimpleAuthenticator2(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
buffer := alloc.NewLocalBuffer(512) cache := make([]byte, 512)
buffer.AppendBytes('1', '2')
payload := []byte{'a', 'b'}
auth := NewSimpleAuthenticator() auth := NewSimpleAuthenticator()
auth.Seal(buffer) b := auth.Seal(cache[:0], nil, payload, nil)
c, err := auth.Open(cache[:0], nil, b, nil)
assert.Bool(auth.Open(buffer)).IsTrue() assert.Error(err).IsNil()
assert.Bytes(buffer.Bytes()).Equals([]byte{'1', '2'}) assert.Bytes(c).Equals(payload)
}
func BenchmarkSimpleAuthenticator(b *testing.B) {
buffer := alloc.NewLocalBuffer(2048)
buffer.FillFullFrom(rand.Reader, 1024)
auth := NewSimpleAuthenticator()
b.SetBytes(int64(buffer.Len()))
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth.Seal(buffer)
auth.Open(buffer)
}
} }

View File

@ -5,8 +5,12 @@ import (
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"crypto/cipher"
"v2ray.com/core/common/alloc" "v2ray.com/core/common/alloc"
"v2ray.com/core/common/dice" "v2ray.com/core/common/dice"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/log" "v2ray.com/core/common/log"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
@ -20,11 +24,32 @@ var (
) )
type ClientConnection struct { type ClientConnection struct {
sync.Mutex sync.RWMutex
net.Conn net.Conn
id internal.ConnectionId id internal.ConnectionId
input func([]byte) input func([]Segment)
auth internet.Authenticator reader PacketReader
writer PacketWriter
}
func (o *ClientConnection) Overhead() int {
o.RLock()
defer o.RUnlock()
if o.writer == nil {
return 0
}
return o.writer.Overhead()
}
func (o *ClientConnection) Write(b []byte) (int, error) {
o.RLock()
defer o.RUnlock()
if o.writer == nil {
return len(b), nil
}
return o.writer.Write(b)
} }
func (o *ClientConnection) Read([]byte) (int, error) { func (o *ClientConnection) Read([]byte) (int, error) {
@ -39,10 +64,26 @@ func (o *ClientConnection) Close() error {
return o.Conn.Close() return o.Conn.Close()
} }
func (o *ClientConnection) Reset(auth internet.Authenticator, inputCallback func([]byte)) { func (o *ClientConnection) Reset(inputCallback func([]Segment)) {
o.Lock() o.Lock()
o.input = inputCallback o.input = inputCallback
o.auth = auth o.Unlock()
}
func (o *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) {
o.Lock()
if o.reader == nil {
o.reader = new(KCPPacketReader)
}
o.reader.(*KCPPacketReader).Header = header
o.reader.(*KCPPacketReader).Security = security
if o.writer == nil {
o.writer = new(KCPPacketWriter)
}
o.writer.(*KCPPacketWriter).Header = header
o.writer.(*KCPPacketWriter).Security = security
o.writer.(*KCPPacketWriter).Writer = o.Conn
o.Unlock() o.Unlock()
} }
@ -57,12 +98,14 @@ func (o *ClientConnection) Run() {
payload.Release() payload.Release()
return return
} }
o.Lock() o.RLock()
if o.input != nil && o.auth.Open(payload) { if o.input != nil {
o.input(payload.Bytes()) segments := o.reader.Read(payload.Bytes())
if len(segments) > 0 {
o.input(segments)
}
} }
o.Unlock() o.RUnlock()
payload.Reset()
} }
} }
@ -93,13 +136,18 @@ func DialKCP(src v2net.Address, dest v2net.Destination, options internet.DialerO
} }
kcpSettings := networkSettings.(*Config) kcpSettings := networkSettings.(*Config)
cpip, err := kcpSettings.GetAuthenticator() clientConn := conn.(*ClientConnection)
header, err := kcpSettings.GetPackerHeader()
if err != nil { if err != nil {
log.Error("KCP|Dialer: Failed to create authenticator: ", err) return nil, errors.Base(err).Message("KCP|Dialer: Failed to create packet header.")
return nil, err
} }
security, err := kcpSettings.GetSecurity()
if err != nil {
return nil, errors.Base(err).Message("KCP|Dialer: Failed to create security.")
}
clientConn.ResetSecurity(header, security)
conv := uint16(atomic.AddUint32(&globalConv, 1)) conv := uint16(atomic.AddUint32(&globalConv, 1))
session := NewConnection(conv, conn.(*ClientConnection), globalPool, cpip, kcpSettings) session := NewConnection(conv, clientConn, globalPool, kcpSettings)
var iConn internet.Connection var iConn internet.Connection
iConn = session iConn = session

View File

@ -0,0 +1,92 @@
package kcp
import (
"crypto/cipher"
"crypto/rand"
"io"
"v2ray.com/core/transport/internet"
)
type PacketReader interface {
Read([]byte) []Segment
}
type PacketWriter interface {
Overhead() int
io.Writer
}
type KCPPacketReader struct {
Security cipher.AEAD
Header internet.PacketHeader
}
func (v *KCPPacketReader) Read(b []byte) []Segment {
if v.Header != nil {
b = b[v.Header.Size():]
}
if v.Security != nil {
nonceSize := v.Security.NonceSize()
out, err := v.Security.Open(b[nonceSize:nonceSize], b[:nonceSize], b[nonceSize:], nil)
if err != nil {
return nil
}
b = out
}
var result []Segment
for len(b) > 0 {
seg, x := ReadSegment(b)
if seg == nil {
break
}
result = append(result, seg)
b = x
}
return result
}
type KCPPacketWriter struct {
Header internet.PacketHeader
Security cipher.AEAD
Writer io.Writer
buffer [32 * 1024]byte
}
func (v *KCPPacketWriter) Overhead() int {
overhead := 0
if v.Header != nil {
overhead += v.Header.Size()
}
if v.Security != nil {
overhead += v.Security.Overhead()
}
return overhead
}
func (v *KCPPacketWriter) Write(b []byte) (int, error) {
x := v.buffer[:]
size := 0
if v.Header != nil {
nBytes := v.Header.Write(x)
size += nBytes
x = x[nBytes:]
}
if v.Security != nil {
nonceSize := v.Security.NonceSize()
var nonce []byte
if nonceSize > 0 {
nonce = x[:nonceSize]
rand.Read(nonce)
x = x[nonceSize:]
}
x = v.Security.Seal(x[:0], nonce, b, nil)
size += nonceSize + len(x)
} else {
size += copy(x, b)
}
_, err := v.Writer.Write(v.buffer[:size])
return len(b), err
}

View File

@ -0,0 +1 @@
package kcp_test

View File

@ -2,14 +2,17 @@ package kcp
import ( import (
"crypto/tls" "crypto/tls"
"io"
"net" "net"
"sync" "sync"
"time" "time"
"crypto/cipher"
"v2ray.com/core/common/alloc" "v2ray.com/core/common/alloc"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/log" "v2ray.com/core/common/log"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/proxy" "v2ray.com/core/proxy"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/internal" "v2ray.com/core/transport/internet/internal"
@ -25,11 +28,14 @@ type ConnectionId struct {
type ServerConnection struct { type ServerConnection struct {
id internal.ConnectionId id internal.ConnectionId
writer *Writer
local net.Addr local net.Addr
remote net.Addr remote net.Addr
auth internet.Authenticator writer PacketWriter
input func([]byte) closer io.Closer
}
func (o *ServerConnection) Overhead() int {
return o.writer.Overhead()
} }
func (o *ServerConnection) Read([]byte) (int, error) { func (o *ServerConnection) Read([]byte) (int, error) {
@ -41,20 +47,10 @@ func (o *ServerConnection) Write(b []byte) (int, error) {
} }
func (o *ServerConnection) Close() error { func (o *ServerConnection) Close() error {
return o.writer.Close() return o.closer.Close()
} }
func (o *ServerConnection) Reset(auth internet.Authenticator, input func([]byte)) { func (o *ServerConnection) Reset(input func([]Segment)) {
o.auth = auth
o.input = input
}
func (o *ServerConnection) Input(b *alloc.Buffer) {
defer b.Release()
if o.auth.Open(b) {
o.input(b.Bytes())
}
} }
func (o *ServerConnection) LocalAddr() net.Addr { func (o *ServerConnection) LocalAddr() net.Addr {
@ -85,12 +81,14 @@ func (o *ServerConnection) Id() internal.ConnectionId {
type Listener struct { type Listener struct {
sync.Mutex sync.Mutex
running bool running bool
authenticator internet.Authenticator
sessions map[ConnectionId]*Connection sessions map[ConnectionId]*Connection
awaitingConns chan *Connection awaitingConns chan *Connection
hub *udp.UDPHub hub *udp.UDPHub
tlsConfig *tls.Config tlsConfig *tls.Config
config *Config config *Config
reader PacketReader
header internet.PacketHeader
security cipher.AEAD
} }
func NewListener(address v2net.Address, port v2net.Port, options internet.ListenOptions) (*Listener, error) { func NewListener(address v2net.Address, port v2net.Port, options internet.ListenOptions) (*Listener, error) {
@ -102,12 +100,21 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
kcpSettings := networkSettings.(*Config) kcpSettings := networkSettings.(*Config)
kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false} kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
auth, err := kcpSettings.GetAuthenticator() header, err := kcpSettings.GetPackerHeader()
if err != nil { if err != nil {
return nil, err return nil, errors.Base(err).Message("KCP|Listener: Failed to create packet header.")
}
security, err := kcpSettings.GetSecurity()
if err != nil {
return nil, errors.Base(err).Message("KCP|Listener: Failed to create security.")
} }
l := &Listener{ l := &Listener{
authenticator: auth, header: header,
security: security,
reader: &KCPPacketReader{
Header: header,
Security: security,
},
sessions: make(map[ConnectionId]*Connection), sessions: make(map[ConnectionId]*Connection),
awaitingConns: make(chan *Connection, 64), awaitingConns: make(chan *Connection, 64),
running: true, running: true,
@ -138,10 +145,12 @@ func (v *Listener) OnReceive(payload *alloc.Buffer, session *proxy.SessionInfo)
src := session.Source src := session.Source
if valid := v.authenticator.Open(payload); !valid { segments := v.reader.Read(payload.Bytes())
if len(segments) == 0 {
log.Info("KCP|Listener: discarding invalid payload from ", src) log.Info("KCP|Listener: discarding invalid payload from ", src)
return return
} }
if !v.running { if !v.running {
return return
} }
@ -153,8 +162,9 @@ func (v *Listener) OnReceive(payload *alloc.Buffer, session *proxy.SessionInfo)
if payload.Len() < 4 { if payload.Len() < 4 {
return return
} }
conv := serial.BytesToUint16(payload.BytesTo(2)) conv := segments[0].Conversation()
cmd := Command(payload.Byte(2)) cmd := segments[0].Command()
id := ConnectionId{ id := ConnectionId{
Remote: src.Address, Remote: src.Address,
Port: src.Port, Port: src.Port,
@ -177,17 +187,18 @@ func (v *Listener) OnReceive(payload *alloc.Buffer, session *proxy.SessionInfo)
Port: int(src.Port), Port: int(src.Port),
} }
localAddr := v.hub.Addr() localAddr := v.hub.Addr()
auth, err := v.config.GetAuthenticator()
if err != nil {
log.Error("KCP|Listener: Failed to create authenticator: ", err)
}
sConn := &ServerConnection{ sConn := &ServerConnection{
id: internal.NewConnectionId(v2net.LocalHostIP, src), id: internal.NewConnectionId(v2net.LocalHostIP, src),
local: localAddr, local: localAddr,
remote: remoteAddr, remote: remoteAddr,
writer: writer, writer: &KCPPacketWriter{
Header: v.header,
Writer: writer,
Security: v.security,
},
closer: writer,
} }
conn = NewConnection(conv, sConn, v, auth, v.config) conn = NewConnection(conv, sConn, v, v.config)
select { select {
case v.awaitingConns <- conn: case v.awaitingConns <- conn:
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
@ -196,7 +207,7 @@ func (v *Listener) OnReceive(payload *alloc.Buffer, session *proxy.SessionInfo)
} }
v.sessions[id] = conn v.sessions[id] = conn
} }
conn.Input(payload.Bytes()) conn.Input(segments)
} }
func (v *Listener) Remove(id ConnectionId) { func (v *Listener) Remove(id ConnectionId) {

View File

@ -5,8 +5,6 @@ import (
"sync" "sync"
"v2ray.com/core/common/alloc" "v2ray.com/core/common/alloc"
v2io "v2ray.com/core/common/io"
"v2ray.com/core/transport/internet"
) )
type SegmentWriter interface { type SegmentWriter interface {
@ -17,13 +15,14 @@ type BufferedSegmentWriter struct {
sync.Mutex sync.Mutex
mtu uint32 mtu uint32
buffer *alloc.Buffer buffer *alloc.Buffer
writer v2io.Writer writer io.Writer
} }
func NewSegmentWriter(writer *AuthenticationWriter) *BufferedSegmentWriter { func NewSegmentWriter(writer io.Writer, mtu uint32) *BufferedSegmentWriter {
return &BufferedSegmentWriter{ return &BufferedSegmentWriter{
mtu: writer.Mtu(), mtu: mtu,
writer: writer, writer: writer,
buffer: alloc.NewSmallBuffer(),
} }
} }
@ -36,45 +35,21 @@ func (v *BufferedSegmentWriter) Write(seg Segment) {
v.FlushWithoutLock() v.FlushWithoutLock()
} }
if v.buffer == nil {
v.buffer = alloc.NewSmallBuffer()
}
v.buffer.AppendFunc(seg.Bytes()) v.buffer.AppendFunc(seg.Bytes())
} }
func (v *BufferedSegmentWriter) FlushWithoutLock() { func (v *BufferedSegmentWriter) FlushWithoutLock() {
v.writer.Write(v.buffer) v.writer.Write(v.buffer.Bytes())
v.buffer = nil v.buffer.Clear()
} }
func (v *BufferedSegmentWriter) Flush() { func (v *BufferedSegmentWriter) Flush() {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
if v.buffer.Len() == 0 { if v.buffer.IsEmpty() {
return return
} }
v.FlushWithoutLock() v.FlushWithoutLock()
} }
type AuthenticationWriter struct {
Authenticator internet.Authenticator
Writer io.Writer
Config *Config
}
func (v *AuthenticationWriter) Write(payload *alloc.Buffer) error {
defer payload.Release()
v.Authenticator.Seal(payload)
_, err := v.Writer.Write(payload.Bytes())
return err
}
func (v *AuthenticationWriter) Release() {}
func (v *AuthenticationWriter) Mtu() uint32 {
return v.Config.Mtu.GetValue() - uint32(v.Authenticator.Overhead())
}

View File

@ -24,6 +24,7 @@ const (
type Segment interface { type Segment interface {
common.Releasable common.Releasable
Conversation() uint16 Conversation() uint16
Command() Command
ByteSize() int ByteSize() int
Bytes() alloc.BytesWriter Bytes() alloc.BytesWriter
} }
@ -52,6 +53,10 @@ func (v *DataSegment) Conversation() uint16 {
return v.Conv return v.Conv
} }
func (v *DataSegment) Command() Command {
return CommandData
}
func (v *DataSegment) SetData(b []byte) { func (v *DataSegment) SetData(b []byte) {
if v.Data == nil { if v.Data == nil {
v.Data = alloc.NewSmallBuffer() v.Data = alloc.NewSmallBuffer()
@ -104,6 +109,10 @@ func (v *AckSegment) Conversation() uint16 {
return v.Conv return v.Conv
} }
func (v *AckSegment) Command() Command {
return CommandACK
}
func (v *AckSegment) PutTimestamp(timestamp uint32) { func (v *AckSegment) PutTimestamp(timestamp uint32) {
if timestamp-v.Timestamp < 0x7FFFFFFF { if timestamp-v.Timestamp < 0x7FFFFFFF {
v.Timestamp = timestamp v.Timestamp = timestamp
@ -144,7 +153,7 @@ func (v *AckSegment) Release() {
type CmdOnlySegment struct { type CmdOnlySegment struct {
Conv uint16 Conv uint16
Command Command Cmd Command
Option SegmentOption Option SegmentOption
SendingNext uint32 SendingNext uint32
ReceivinNext uint32 ReceivinNext uint32
@ -159,6 +168,10 @@ func (v *CmdOnlySegment) Conversation() uint16 {
return v.Conv return v.Conv
} }
func (v *CmdOnlySegment) Command() Command {
return v.Cmd
}
func (v *CmdOnlySegment) ByteSize() int { func (v *CmdOnlySegment) ByteSize() int {
return 2 + 1 + 1 + 4 + 4 + 4 return 2 + 1 + 1 + 4 + 4 + 4
} }
@ -166,7 +179,7 @@ func (v *CmdOnlySegment) ByteSize() int {
func (v *CmdOnlySegment) Bytes() alloc.BytesWriter { func (v *CmdOnlySegment) Bytes() alloc.BytesWriter {
return func(b []byte) int { return func(b []byte) int {
b = serial.Uint16ToBytes(v.Conv, b[:0]) b = serial.Uint16ToBytes(v.Conv, b[:0])
b = append(b, byte(v.Command), byte(v.Option)) b = append(b, byte(v.Cmd), byte(v.Option))
b = serial.Uint32ToBytes(v.SendingNext, b) b = serial.Uint32ToBytes(v.SendingNext, b)
b = serial.Uint32ToBytes(v.ReceivinNext, b) b = serial.Uint32ToBytes(v.ReceivinNext, b)
b = serial.Uint32ToBytes(v.PeerRTO, b) b = serial.Uint32ToBytes(v.PeerRTO, b)
@ -250,7 +263,7 @@ func ReadSegment(buf []byte) (Segment, []byte) {
seg := NewCmdOnlySegment() seg := NewCmdOnlySegment()
seg.Conv = conv seg.Conv = conv
seg.Command = cmd seg.Cmd = cmd
seg.Option = opt seg.Option = opt
if len(buf) < 12 { if len(buf) < 12 {

View File

@ -79,7 +79,7 @@ func TestCmdSegment(t *testing.T) {
seg := &CmdOnlySegment{ seg := &CmdOnlySegment{
Conv: 1, Conv: 1,
Command: CommandPing, Cmd: CommandPing,
Option: SegmentOptionClose, Option: SegmentOptionClose,
SendingNext: 11, SendingNext: 11,
ReceivinNext: 13, ReceivinNext: 13,
@ -95,7 +95,7 @@ func TestCmdSegment(t *testing.T) {
iseg, _ := ReadSegment(bytes) iseg, _ := ReadSegment(bytes)
seg2 := iseg.(*CmdOnlySegment) seg2 := iseg.(*CmdOnlySegment)
assert.Uint16(seg2.Conv).Equals(seg.Conv) assert.Uint16(seg2.Conv).Equals(seg.Conv)
assert.Byte(byte(seg2.Command)).Equals(byte(seg.Command)) assert.Byte(byte(seg2.Command())).Equals(byte(seg.Command()))
assert.Byte(byte(seg2.Option)).Equals(byte(seg.Option)) assert.Byte(byte(seg2.Option)).Equals(byte(seg.Option))
assert.Uint32(seg2.SendingNext).Equals(seg.SendingNext) assert.Uint32(seg2.SendingNext).Equals(seg.SendingNext)
assert.Uint32(seg2.ReceivinNext).Equals(seg.ReceivinNext) assert.Uint32(seg2.ReceivinNext).Equals(seg.ReceivinNext)