diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go new file mode 100644 index 00000000..6de1b687 --- /dev/null +++ b/transport/internet/kcp/connection.go @@ -0,0 +1,406 @@ +package kcp + +import ( + "errors" + "io" + "net" + "sync" + "time" + + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/common/log" +) + +var ( + errTimeout = errors.New("i/o timeout") + errBrokenPipe = errors.New("broken pipe") + errClosedListener = errors.New("Listener closed.") +) + +const ( + basePort = 20000 // minimum port for listening + maxPort = 65535 // maximum port for listening + defaultWndSize = 128 // default window size, in packet + mtuLimit = 4096 + rxQueueLimit = 8192 + rxFecLimit = 2048 + + headerSize = 2 + cmdData uint16 = 0 + cmdClose uint16 = 1 +) + +type Command byte + +var ( + CommandData Command = 0 + CommandTerminate Command = 1 +) + +type Option byte + +var ( + OptionClose Option = 1 +) + +type ConnState byte + +var ( + ConnStateActive ConnState = 0 + ConnStateReadyToClose ConnState = 1 + ConnStatePeerClosed ConnState = 2 + ConnStateClosed ConnState = 4 +) + +func nowMillisec() int64 { + now := time.Now() + return now.Unix()*1000 + int64(now.Nanosecond()/1000000) +} + +// UDPSession defines a KCP session implemented by UDP +type UDPSession struct { + sync.Mutex + state ConnState + kcp *KCP // the core ARQ + kcpAccess sync.Mutex + block Authenticator + needUpdate bool + local, remote net.Addr + rd time.Time // read deadline + wd time.Time // write deadline + chReadEvent chan struct{} + chWriteEvent chan struct{} + ackNoDelay bool + writer io.WriteCloser + since int64 +} + +// newUDPSession create a new udp session for client or server +func newUDPSession(conv uint32, writerCloser io.WriteCloser, local *net.UDPAddr, remote *net.UDPAddr, block Authenticator) *UDPSession { + sess := new(UDPSession) + sess.local = local + sess.chReadEvent = make(chan struct{}, 1) + sess.chWriteEvent = make(chan struct{}, 1) + sess.remote = remote + sess.block = block + sess.writer = writerCloser + sess.since = nowMillisec() + + mtu := uint32(effectiveConfig.Mtu - block.HeaderSize() - headerSize) + sess.kcp = NewKCP(conv, mtu, func(buf []byte, size int) { + log.Info(sess.local, " kcp output: ", buf[:size]) + if size >= IKCP_OVERHEAD { + ext := alloc.NewBuffer().Clear().Append(buf[:size]) + cmd := cmdData + opt := Option(0) + if sess.state == ConnStateReadyToClose { + opt = OptionClose + } + ext.Prepend([]byte{byte(cmd), byte(opt)}) + sess.output(ext) + } + }) + sess.kcp.WndSize(effectiveConfig.Sndwnd, effectiveConfig.Rcvwnd) + sess.kcp.NoDelay(1, 20, 2, 1) + sess.ackNoDelay = effectiveConfig.Acknodelay + sess.kcp.current = sess.Elapsed() + + go sess.updateTask() + + log.Info("Created KCP conn to ", sess.RemoteAddr()) + return sess +} + +func (this *UDPSession) Elapsed() uint32 { + return uint32(nowMillisec() - this.since) +} + +// Read implements the Conn Read method. +func (s *UDPSession) Read(b []byte) (int, error) { + if s.state == ConnStateReadyToClose || s.state == ConnStateClosed { + return 0, io.EOF + } + + for { + s.Lock() + if s.state == ConnStateReadyToClose || s.state == ConnStateClosed { + s.Unlock() + return 0, io.EOF + } + + if !s.rd.IsZero() { + if time.Now().After(s.rd) { + s.Unlock() + return 0, errTimeout + } + } + + nBytes := s.kcp.Recv(b) + if nBytes > 0 { + s.Unlock() + return nBytes, nil + } + + var timeout <-chan time.Time + if !s.rd.IsZero() { + delay := s.rd.Sub(time.Now()) + timeout = time.After(delay) + } + + s.Unlock() + select { + case <-s.chReadEvent: + case <-timeout: + return 0, errTimeout + } + } +} + +// Write implements the Conn Write method. +func (s *UDPSession) Write(b []byte) (int, error) { + log.Info("Trying to write ", len(b), " bytes. ", s.local) + if s.state == ConnStateReadyToClose || + s.state == ConnStatePeerClosed || + s.state == ConnStateClosed { + return 0, io.ErrClosedPipe + } + + for { + s.Lock() + if s.state == ConnStateReadyToClose || + s.state == ConnStatePeerClosed || + s.state == ConnStateClosed { + s.Unlock() + return 0, io.ErrClosedPipe + } + + if !s.wd.IsZero() { + if time.Now().After(s.wd) { // timeout + s.Unlock() + return 0, errTimeout + } + } + + if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) { + nBytes := len(b) + log.Info("Writing ", nBytes, " bytes.", s.local) + s.kcp.Send(b) + s.kcp.current = s.Elapsed() + s.kcp.flush() + s.Unlock() + return nBytes, nil + } + + var timeout <-chan time.Time + if !s.wd.IsZero() { + delay := s.wd.Sub(time.Now()) + timeout = time.After(delay) + } + s.Unlock() + + // wait for write event or timeout + select { + case <-s.chWriteEvent: + case <-timeout: + return 0, errTimeout + } + } +} + +func (this *UDPSession) Terminate() { + if this.state == ConnStateClosed { + return + } + this.Lock() + defer this.Unlock() + + this.state = ConnStateClosed + this.writer.Close() +} + +func (this *UDPSession) NotifyTermination() { + for i := 0; i < 16; i++ { + this.Lock() + if this.state == ConnStateClosed { + this.Unlock() + return + } + buffer := alloc.NewSmallBuffer().Clear() + buffer.AppendBytes(byte(CommandTerminate), byte(OptionClose), byte(0), byte(0), byte(0), byte(0)) + this.output(buffer) + time.Sleep(time.Second) + this.Unlock() + } + this.Terminate() +} + +// Close closes the connection. +func (s *UDPSession) Close() error { + log.Info("Closed ", s.local) + s.Lock() + defer s.Unlock() + + if s.state == ConnStateActive { + s.state = ConnStateReadyToClose + if s.kcp.WaitSnd() == 0 { + go s.NotifyTermination() + } + } + + if s.state == ConnStatePeerClosed { + go s.Terminate() + } + + return nil +} + +// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. +func (s *UDPSession) LocalAddr() net.Addr { + return s.local +} + +// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. +func (s *UDPSession) RemoteAddr() net.Addr { return s.remote } + +// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. +func (s *UDPSession) SetDeadline(t time.Time) error { + s.Lock() + defer s.Unlock() + s.rd = t + s.wd = t + return nil +} + +// SetReadDeadline implements the Conn SetReadDeadline method. +func (s *UDPSession) SetReadDeadline(t time.Time) error { + s.Lock() + defer s.Unlock() + s.rd = t + return nil +} + +// SetWriteDeadline implements the Conn SetWriteDeadline method. +func (s *UDPSession) SetWriteDeadline(t time.Time) error { + s.Lock() + defer s.Unlock() + s.wd = t + return nil +} + +func (s *UDPSession) output(payload *alloc.Buffer) { + defer payload.Release() + + if s.state == ConnStatePeerClosed || s.state == ConnStateClosed { + return + } + s.block.Seal(payload) + + s.writer.Write(payload.Value) +} + +// kcp update, input loop +func (s *UDPSession) updateTask() { + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + + var nextupdate uint32 = 0 + for range ticker.C { + s.Lock() + if s.state == ConnStateClosed { + s.Unlock() + return + } + current := s.Elapsed() + if !s.needUpdate && nextupdate == 0 { + nextupdate = s.kcp.Check(current) + } + current = s.Elapsed() + if s.needUpdate || current >= nextupdate { + log.Info("Updating KCP: ", current, " addr ", s.LocalAddr()) + s.kcp.Update(current) + nextupdate = s.kcp.Check(current) + s.needUpdate = false + } + if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) { + s.notifyWriteEvent() + } + s.Unlock() + } +} + +func (s *UDPSession) notifyReadEvent() { + select { + case s.chReadEvent <- struct{}{}: + default: + } +} + +func (s *UDPSession) notifyWriteEvent() { + select { + case s.chWriteEvent <- struct{}{}: + default: + } +} + +func (this *UDPSession) MarkPeerClose() { + this.Lock() + defer this.Unlock() + if this.state == ConnStateReadyToClose { + this.state = ConnStateClosed + go this.Terminate() + return + } + if this.state == ConnStateActive { + this.state = ConnStatePeerClosed + } +} + +func (s *UDPSession) kcpInput(data []byte) { + cmd := Command(data[0]) + opt := Option(data[1]) + if cmd == CommandTerminate { + go s.Terminate() + return + } + if opt == OptionClose { + go s.MarkPeerClose() + } + s.kcpAccess.Lock() + s.kcp.current = s.Elapsed() + log.Info(s.local, " kcp input: ", data[2:]) + ret := s.kcp.Input(data[2:]) + log.Info("kcp input returns ", ret) + + if s.ackNoDelay { + s.kcp.current = s.Elapsed() + s.kcp.flush() + } else { + s.needUpdate = true + } + s.kcpAccess.Unlock() + s.notifyReadEvent() +} + +func (this *UDPSession) FetchInputFrom(conn net.Conn) { + go func() { + for { + payload := alloc.NewBuffer() + nBytes, err := conn.Read(payload.Value) + if err != nil { + return + } + payload.Slice(0, nBytes) + if this.block.Open(payload) { + log.Info("Client fetching ", payload.Len(), " bytes.") + this.kcpInput(payload.Value) + } + payload.Release() + } + }() +} + +func (this *UDPSession) Reusable() bool { + return false +} + +func (this *UDPSession) SetReusable(b bool) {} diff --git a/transport/internet/kcp/crypt.go b/transport/internet/kcp/crypt.go index 02e963ab..7d769474 100644 --- a/transport/internet/kcp/crypt.go +++ b/transport/internet/kcp/crypt.go @@ -1,23 +1,64 @@ package kcp -type BlockCrypt interface { +import ( + "hash/fnv" + + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/common/serial" +) + +type Authenticator interface { + HeaderSize() int // Encrypt encrypts the whole block in src into dst. // Dst and src may point at the same memory. - Encrypt(dst, src []byte) + Seal(buffer *alloc.Buffer) // Decrypt decrypts the whole block in src into dst. // Dst and src may point at the same memory. - Decrypt(dst, src []byte) + Open(buffer *alloc.Buffer) bool +} + +type SimpleAuthenticator struct{} + +func NewSimpleAuthenticator() Authenticator { + return &SimpleAuthenticator{} } -// None Encryption -type NoneBlockCrypt struct { - xortbl []byte +func (this *SimpleAuthenticator) HeaderSize() int { + return 6 } -func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) { - return new(NoneBlockCrypt), nil +func (this *SimpleAuthenticator) Seal(buffer *alloc.Buffer) { + var length uint16 = uint16(buffer.Len()) + buffer.Prepend(serial.Uint16ToBytes(length)) + fnvHash := fnv.New32a() + fnvHash.Write(buffer.Value) + + buffer.SliceBack(4) + fnvHash.Sum(buffer.Value[:0]) + + for i := 4; i < buffer.Len(); i++ { + buffer.Value[i] ^= buffer.Value[i-4] + } } -func (c *NoneBlockCrypt) Encrypt(dst, src []byte) {} -func (c *NoneBlockCrypt) Decrypt(dst, src []byte) {} +func (this *SimpleAuthenticator) Open(buffer *alloc.Buffer) bool { + for i := buffer.Len() - 1; i >= 4; i-- { + buffer.Value[i] ^= buffer.Value[i-4] + } + + fnvHash := fnv.New32a() + fnvHash.Write(buffer.Value[4:]) + if serial.BytesToUint32(buffer.Value[:4]) != fnvHash.Sum32() { + return false + } + + length := serial.BytesToUint16(buffer.Value[4:6]) + if buffer.Len()-6 != int(length) { + return false + } + + buffer.SliceFrom(6) + + return true +} diff --git a/transport/internet/kcp/crypt_test.go b/transport/internet/kcp/crypt_test.go new file mode 100644 index 00000000..e57c28e4 --- /dev/null +++ b/transport/internet/kcp/crypt_test.go @@ -0,0 +1,22 @@ +package kcp_test + +import ( + "testing" + + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/testing/assert" + . "github.com/v2ray/v2ray-core/transport/internet/kcp" +) + +func TestSimpleAuthenticator(t *testing.T) { + assert := assert.On(t) + + buffer := alloc.NewBuffer().Clear() + buffer.AppendBytes('a', 'b', 'c', 'd', 'e', 'f', 'g') + + auth := NewSimpleAuthenticator() + auth.Seal(buffer) + + assert.Bool(auth.Open(buffer)).IsTrue() + assert.String(buffer.String()).Equals("abcdefg") +} diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index a66fd647..5b22b63e 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -5,6 +5,7 @@ import ( "math/rand" "net" + "github.com/v2ray/v2ray-core/common/log" v2net "github.com/v2ray/v2ray-core/common/net" "github.com/v2ray/v2ray-core/transport/internet" ) @@ -14,37 +15,18 @@ var ( ) func DialKCP(src v2net.Address, dest v2net.Destination) (internet.Connection, error) { - var ip net.IP - if dest.Address().IsDomain() { - ips, err := net.LookupIP(dest.Address().Domain()) - if err != nil { - return nil, err - } - if len(ips) == 0 { - return nil, ErrUnknownDestination - } - ip = ips[0] - } else { - ip = dest.Address().IP() - } - udpAddr := &net.UDPAddr{ - IP: ip, - Port: int(dest.Port()), - } - - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{}) + log.Info("Dialling KCP to ", dest) + udpDest := v2net.UDPDestination(dest.Address(), dest.Port()) + conn, err := internet.DialToDest(src, udpDest) if err != nil { return nil, err } - cpip, _ := NewNoneBlockCrypt(nil) - session := newUDPSession(rand.Uint32(), nil, udpConn, udpAddr, cpip) - kcvn := &KCPVconn{hc: session} - err = kcvn.ApplyConf() - if err != nil { - return nil, err - } - return kcvn, nil + cpip := NewSimpleAuthenticator() + session := newUDPSession(rand.Uint32(), conn, conn.LocalAddr().(*net.UDPAddr), conn.RemoteAddr().(*net.UDPAddr), cpip) + session.FetchInputFrom(conn) + + return session, nil } func init() { diff --git a/transport/internet/kcp/kcp.go b/transport/internet/kcp/kcp.go index b1f58086..42f6e66b 100644 --- a/transport/internet/kcp/kcp.go +++ b/transport/internet/kcp/kcp.go @@ -22,7 +22,7 @@ const ( IKCP_ASK_TELL = 2 // need to send IKCP_CMD_WINS IKCP_WND_SND = 32 IKCP_WND_RCV = 32 - IKCP_MTU_DEF = 1400 + IKCP_MTU_DEF = 1350 IKCP_ACK_FAST = 3 IKCP_INTERVAL = 100 IKCP_OVERHEAD = 24 @@ -156,13 +156,13 @@ type KCP struct { // NewKCP create a new kcp control object, 'conv' must equal in two endpoint // from the same connection. -func NewKCP(conv uint32, output Output) *KCP { +func NewKCP(conv uint32, mtu uint32, output Output) *KCP { kcp := new(KCP) kcp.conv = conv kcp.snd_wnd = IKCP_WND_SND kcp.rcv_wnd = IKCP_WND_RCV kcp.rmt_wnd = IKCP_WND_RCV - kcp.mtu = IKCP_MTU_DEF + kcp.mtu = mtu kcp.mss = kcp.mtu - IKCP_OVERHEAD kcp.buffer = make([]byte, (kcp.mtu+IKCP_OVERHEAD)*3) kcp.rx_rto = IKCP_RTO_DEF @@ -206,14 +206,14 @@ func (kcp *KCP) Recv(buffer []byte) (n int) { return -1 } - peeksize := kcp.PeekSize() - if peeksize < 0 { - return -2 - } + //peeksize := kcp.PeekSize() + //if peeksize < 0 { + // return -2 + //} - if peeksize > len(buffer) { - return -3 - } + //if peeksize > len(buffer) { + // return -3 + //} var fast_recover bool if len(kcp.rcv_queue) >= int(kcp.rcv_wnd) { @@ -224,13 +224,13 @@ func (kcp *KCP) Recv(buffer []byte) (n int) { count := 0 for k := range kcp.rcv_queue { seg := &kcp.rcv_queue[k] + if len(seg.data) > len(buffer) { + break + } copy(buffer, seg.data) buffer = buffer[len(seg.data):] n += len(seg.data) count++ - if seg.frg == 0 { - break - } } kcp.rcv_queue = kcp.rcv_queue[count:] @@ -901,3 +901,7 @@ func (kcp *KCP) WndSize(sndwnd, rcvwnd int) int { func (kcp *KCP) WaitSnd() int { return len(kcp.snd_buf) + len(kcp.snd_queue) } + +func (kcp *KCP) WaitRcv() int { + return len(kcp.rcv_buf) + len(kcp.rcv_queue) +} diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go new file mode 100644 index 00000000..386727c8 --- /dev/null +++ b/transport/internet/kcp/listener.go @@ -0,0 +1,160 @@ +package kcp + +import ( + "encoding/binary" + "net" + "sync" + "time" + + "github.com/v2ray/v2ray-core/common/alloc" + "github.com/v2ray/v2ray-core/common/log" + v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/transport/internet" + "github.com/v2ray/v2ray-core/transport/internet/udp" +) + +// Listener defines a server listening for connections +type Listener struct { + sync.Mutex + running bool + block Authenticator + sessions map[string]*UDPSession + awaitingConns chan *UDPSession + hub *udp.UDPHub + localAddr *net.UDPAddr +} + +func NewListener(address v2net.Address, port v2net.Port) (*Listener, error) { + log.Info("Creating listener on ", address, ":", port) + l := &Listener{ + block: NewSimpleAuthenticator(), + sessions: make(map[string]*UDPSession), + awaitingConns: make(chan *UDPSession, 64), + localAddr: &net.UDPAddr{ + IP: address.IP(), + Port: int(port), + }, + running: true, + } + hub, err := udp.ListenUDP(address, port, l.OnReceive) + if err != nil { + return nil, err + } + l.hub = hub + log.Info("Listener created.") + return l, nil +} + +func (this *Listener) OnReceive(payload *alloc.Buffer, src v2net.Destination) { + log.Info("Listener on receive from ", src) + defer payload.Release() + + if valid := this.block.Open(payload); !valid { + log.Info("Listern discarding invalid payload.") + return + } + if !this.running { + return + } + this.Lock() + defer this.Unlock() + if !this.running { + return + } + srcAddrStr := src.NetAddr() + conn, found := this.sessions[srcAddrStr] + if !found { + conv := binary.LittleEndian.Uint32(payload.Value[2:6]) + writer := &Writer{ + hub: this.hub, + dest: src, + listener: this, + } + srcAddr := &net.UDPAddr{ + IP: src.Address().IP(), + Port: int(src.Port()), + } + log.Info("Listener creating new connection.") + conn = newUDPSession(conv, writer, this.localAddr, srcAddr, this.block) + select { + case this.awaitingConns <- conn: + case <-time.After(time.Second * 5): + conn.Close() + return + } + this.sessions[srcAddrStr] = conn + } + conn.kcpInput(payload.Value) +} + +func (this *Listener) Remove(dest string) { + if !this.running { + return + } + this.Lock() + defer this.Unlock() + if !this.running { + return + } + delete(this.sessions, dest) +} + +// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. +func (this *Listener) Accept() (internet.Connection, error) { + for { + if !this.running { + return nil, errClosedListener + } + select { + case conn := <-this.awaitingConns: + log.Info("Accepting connection from ", conn.RemoteAddr()) + return conn, nil + case <-time.After(time.Second): + + } + } +} + +// Close stops listening on the UDP address. Already Accepted connections are not closed. +func (this *Listener) Close() error { + if !this.running { + return errClosedListener + } + this.Lock() + defer this.Unlock() + + this.running = false + close(this.awaitingConns) + this.hub.Close() + + return nil +} + +// Addr returns the listener's network address, The Addr returned is shared by all invocations of Addr, so do not modify it. +func (this *Listener) Addr() net.Addr { + return this.localAddr +} + +type Writer struct { + dest v2net.Destination + hub *udp.UDPHub + listener *Listener +} + +func (this *Writer) Write(payload []byte) (int, error) { + log.Info("Writer writing to ", this.dest, " with ", len(payload), " bytes.") + return this.hub.WriteTo(payload, this.dest) +} + +func (this *Writer) Close() error { + this.listener.Remove(this.dest.NetAddr()) + return nil +} + +func ListenKCP(address v2net.Address, port v2net.Port) (internet.Listener, error) { + return NewListener(address, port) +} + +func init() { + internet.KCPListenFunc = ListenKCP +} diff --git a/transport/internet/kcp/sess.go b/transport/internet/kcp/sess.go deleted file mode 100644 index 7a53e236..00000000 --- a/transport/internet/kcp/sess.go +++ /dev/null @@ -1,563 +0,0 @@ -package kcp - -import ( - crand "crypto/rand" - "encoding/binary" - "errors" - "hash/crc32" - "io" - "log" - "math/rand" - "net" - "sync" - "time" - - "golang.org/x/net/ipv4" -) - -var ( - errTimeout = errors.New("i/o timeout") - errBrokenPipe = errors.New("broken pipe") -) - -const ( - basePort = 20000 // minimum port for listening - maxPort = 65535 // maximum port for listening - defaultWndSize = 128 // default window size, in packet - otpSize = 16 // magic number - crcSize = 4 // 4bytes packet checksum - cryptHeaderSize = otpSize + crcSize - connTimeout = 60 * time.Second - mtuLimit = 4096 - rxQueueLimit = 8192 - rxFecLimit = 2048 -) - -type ( - // UDPSession defines a KCP session implemented by UDP - UDPSession struct { - kcp *KCP // the core ARQ - conn *net.UDPConn // the underlying UDP socket - block BlockCrypt - needUpdate bool - l *Listener // point to server listener if it's a server socket - local, remote net.Addr - rd time.Time // read deadline - wd time.Time // write deadline - sockbuff []byte // kcp receiving is based on packet, I turn it into stream - die chan struct{} - isClosed bool - mu sync.Mutex - chReadEvent chan struct{} - chWriteEvent chan struct{} - chTicker chan time.Time - chUDPOutput chan []byte - headerSize int - lastInputTs time.Time - ackNoDelay bool - } -) - -// newUDPSession create a new udp session for client or server -func newUDPSession(conv uint32, l *Listener, conn *net.UDPConn, remote *net.UDPAddr, block BlockCrypt) *UDPSession { - sess := new(UDPSession) - sess.chTicker = make(chan time.Time, 1) - sess.chUDPOutput = make(chan []byte, rxQueueLimit) - sess.die = make(chan struct{}) - sess.local = conn.LocalAddr() - sess.chReadEvent = make(chan struct{}, 1) - sess.chWriteEvent = make(chan struct{}, 1) - sess.remote = remote - sess.conn = conn - sess.l = l - sess.block = block - sess.lastInputTs = time.Now() - - // caculate header size - if sess.block != nil { - sess.headerSize += cryptHeaderSize - } - - sess.kcp = NewKCP(conv, func(buf []byte, size int) { - if size >= IKCP_OVERHEAD { - ext := make([]byte, sess.headerSize+size) - copy(ext[sess.headerSize:], buf) - sess.chUDPOutput <- ext - } - }) - sess.kcp.WndSize(defaultWndSize, defaultWndSize) - sess.kcp.SetMtu(IKCP_MTU_DEF - sess.headerSize) - - go sess.updateTask() - go sess.outputTask() - if l == nil { // it's a client connection - go sess.readLoop() - } - - return sess -} - -// Read implements the Conn Read method. -func (s *UDPSession) Read(b []byte) (n int, err error) { - for { - s.mu.Lock() - if len(s.sockbuff) > 0 { // copy from buffer - n = copy(b, s.sockbuff) - s.sockbuff = s.sockbuff[n:] - s.mu.Unlock() - return n, nil - } - - if s.isClosed { - s.mu.Unlock() - return 0, errBrokenPipe - } - - if !s.rd.IsZero() { - if time.Now().After(s.rd) { // timeout - s.mu.Unlock() - return 0, errTimeout - } - } - - if n := s.kcp.PeekSize(); n > 0 { // data arrived - if len(b) >= n { - s.kcp.Recv(b) - } else { - buf := make([]byte, n) - s.kcp.Recv(buf) - n = copy(b, buf) - s.sockbuff = buf[n:] // store remaining bytes into sockbuff for next read - } - s.mu.Unlock() - return n, nil - } - - var timeout <-chan time.Time - if !s.rd.IsZero() { - delay := s.rd.Sub(time.Now()) - timeout = time.After(delay) - } - s.mu.Unlock() - - // wait for read event or timeout - select { - case <-s.chReadEvent: - case <-timeout: - case <-s.die: - } - } -} - -// Write implements the Conn Write method. -func (s *UDPSession) Write(b []byte) (n int, err error) { - for { - s.mu.Lock() - if s.isClosed { - s.mu.Unlock() - return 0, errBrokenPipe - } - - if !s.wd.IsZero() { - if time.Now().After(s.wd) { // timeout - s.mu.Unlock() - return 0, errTimeout - } - } - - if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) { - n = len(b) - max := s.kcp.mss << 8 - for { - if len(b) <= int(max) { // in most cases - s.kcp.Send(b) - break - } else { - s.kcp.Send(b[:max]) - b = b[max:] - } - } - s.kcp.current = currentMs() - s.kcp.flush() - s.mu.Unlock() - return n, nil - } - - var timeout <-chan time.Time - if !s.wd.IsZero() { - delay := s.wd.Sub(time.Now()) - timeout = time.After(delay) - } - s.mu.Unlock() - - // wait for write event or timeout - select { - case <-s.chWriteEvent: - case <-timeout: - case <-s.die: - } - } -} - -// Close closes the connection. -func (s *UDPSession) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.isClosed { - return errBrokenPipe - } - close(s.die) - s.isClosed = true - if s.l == nil { // client socket close - s.conn.Close() - } - - return nil -} - -// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. -func (s *UDPSession) LocalAddr() net.Addr { - return s.local -} - -// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. -func (s *UDPSession) RemoteAddr() net.Addr { return s.remote } - -// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. -func (s *UDPSession) SetDeadline(t time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - s.rd = t - s.wd = t - return nil -} - -// SetReadDeadline implements the Conn SetReadDeadline method. -func (s *UDPSession) SetReadDeadline(t time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - s.rd = t - return nil -} - -// SetWriteDeadline implements the Conn SetWriteDeadline method. -func (s *UDPSession) SetWriteDeadline(t time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - s.wd = t - return nil -} - -// SetWindowSize set maximum window size -func (s *UDPSession) SetWindowSize(sndwnd, rcvwnd int) { - s.mu.Lock() - defer s.mu.Unlock() - s.kcp.WndSize(sndwnd, rcvwnd) -} - -// SetMtu sets the maximum transmission unit -func (s *UDPSession) SetMtu(mtu int) { - s.mu.Lock() - defer s.mu.Unlock() - s.kcp.SetMtu(mtu - s.headerSize) -} - -// SetACKNoDelay changes ack flush option, set true to flush ack immediately, -func (s *UDPSession) SetACKNoDelay(nodelay bool) { - s.mu.Lock() - defer s.mu.Unlock() - s.ackNoDelay = nodelay -} - -// SetNoDelay calls nodelay() of kcp -func (s *UDPSession) SetNoDelay(nodelay, interval, resend, nc int) { - s.mu.Lock() - defer s.mu.Unlock() - s.kcp.NoDelay(nodelay, interval, resend, nc) -} - -// SetDSCP sets the DSCP field of IP header -func (s *UDPSession) SetDSCP(tos int) { - s.mu.Lock() - defer s.mu.Unlock() - if err := ipv4.NewConn(s.conn).SetTOS(tos << 2); err != nil { - log.Println("set tos:", err) - } -} - -func (s *UDPSession) outputTask() { - // ping - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - for { - select { - case ext := <-s.chUDPOutput: - if s.block != nil { - io.ReadFull(crand.Reader, ext[:otpSize]) // OTP - checksum := crc32.ChecksumIEEE(ext[cryptHeaderSize:]) - binary.LittleEndian.PutUint32(ext[otpSize:], checksum) - s.block.Encrypt(ext, ext) - } - - //if rand.Intn(100) < 80 { - n, err := s.conn.WriteTo(ext, s.remote) - if err != nil { - log.Println(err, n) - } - //} - - case <-ticker.C: - sz := rand.Intn(IKCP_MTU_DEF - s.headerSize - IKCP_OVERHEAD) - sz += s.headerSize + IKCP_OVERHEAD - ping := make([]byte, sz) - io.ReadFull(crand.Reader, ping) - if s.block != nil { - checksum := crc32.ChecksumIEEE(ping[cryptHeaderSize:]) - binary.LittleEndian.PutUint32(ping[otpSize:], checksum) - s.block.Encrypt(ping, ping) - } - - n, err := s.conn.WriteTo(ping, s.remote) - if err != nil { - log.Println(err, n) - } - case <-s.die: - return - } - } -} - -// kcp update, input loop -func (s *UDPSession) updateTask() { - var tc <-chan time.Time - if s.l == nil { // client - ticker := time.NewTicker(10 * time.Millisecond) - tc = ticker.C - defer ticker.Stop() - } else { - tc = s.chTicker - } - - var nextupdate uint32 - for { - select { - case <-tc: - s.mu.Lock() - current := currentMs() - if current >= nextupdate || s.needUpdate { - s.kcp.Update(current) - nextupdate = s.kcp.Check(current) - } - if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) { - s.notifyWriteEvent() - } - s.needUpdate = false - s.mu.Unlock() - case <-s.die: - if s.l != nil { // has listener - s.l.chDeadlinks <- s.remote - } - return - } - } -} - -// GetConv gets conversation id of a session -func (s *UDPSession) GetConv() uint32 { - return s.kcp.conv -} - -func (s *UDPSession) notifyReadEvent() { - select { - case s.chReadEvent <- struct{}{}: - default: - } -} - -func (s *UDPSession) notifyWriteEvent() { - select { - case s.chWriteEvent <- struct{}{}: - default: - } -} - -func (s *UDPSession) kcpInput(data []byte) { - now := time.Now() - if now.Sub(s.lastInputTs) > connTimeout { - s.Close() - return - } - s.lastInputTs = now - - s.mu.Lock() - s.kcp.current = currentMs() - s.kcp.Input(data) - - if s.ackNoDelay { - s.kcp.current = currentMs() - s.kcp.flush() - } else { - s.needUpdate = true - } - s.mu.Unlock() - s.notifyReadEvent() -} - -func (s *UDPSession) receiver(ch chan []byte) { - for { - data := make([]byte, mtuLimit) - if n, _, err := s.conn.ReadFromUDP(data); err == nil && n >= s.headerSize+IKCP_OVERHEAD { - ch <- data[:n] - } else if err != nil { - return - } - } -} - -// read loop for client session -func (s *UDPSession) readLoop() { - chPacket := make(chan []byte, rxQueueLimit) - go s.receiver(chPacket) - - for { - select { - case data := <-chPacket: - dataValid := false - if s.block != nil { - s.block.Decrypt(data, data) - data = data[otpSize:] - checksum := crc32.ChecksumIEEE(data[crcSize:]) - if checksum == binary.LittleEndian.Uint32(data) { - data = data[crcSize:] - dataValid = true - } - } else if s.block == nil { - dataValid = true - } - - if dataValid { - s.kcpInput(data) - } - case <-s.die: - return - } - } -} - -type ( - // Listener defines a server listening for connections - Listener struct { - block BlockCrypt - conn *net.UDPConn - sessions map[string]*UDPSession - chAccepts chan *UDPSession - chDeadlinks chan net.Addr - headerSize int - die chan struct{} - } - - packet struct { - from *net.UDPAddr - data []byte - } -) - -// monitor incoming data for all connections of server -func (l *Listener) monitor() { - chPacket := make(chan packet, rxQueueLimit) - go l.receiver(chPacket) - ticker := time.NewTicker(10 * time.Millisecond) - defer ticker.Stop() - for { - select { - case p := <-chPacket: - data := p.data - from := p.from - dataValid := false - if l.block != nil { - l.block.Decrypt(data, data) - data = data[otpSize:] - checksum := crc32.ChecksumIEEE(data[crcSize:]) - if checksum == binary.LittleEndian.Uint32(data) { - data = data[crcSize:] - dataValid = true - } - } else if l.block == nil { - dataValid = true - } - - if dataValid { - addr := from.String() - s, ok := l.sessions[addr] - if !ok { // new session - var conv uint32 - convValid := false - - conv = binary.LittleEndian.Uint32(data) - convValid = true - - if convValid { - s := newUDPSession(conv, l, l.conn, from, l.block) - s.kcpInput(data) - l.sessions[addr] = s - l.chAccepts <- s - } - } else { - s.kcpInput(data) - } - } - case deadlink := <-l.chDeadlinks: - delete(l.sessions, deadlink.String()) - case <-l.die: - return - case <-ticker.C: - now := time.Now() - for _, s := range l.sessions { - select { - case s.chTicker <- now: - default: - } - } - } - } -} - -func (l *Listener) receiver(ch chan packet) { - for { - data := make([]byte, mtuLimit) - if n, from, err := l.conn.ReadFromUDP(data); err == nil && n >= l.headerSize+IKCP_OVERHEAD { - ch <- packet{from, data[:n]} - } else if err != nil { - return - } - } -} - -// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. -func (l *Listener) Accept() (*UDPSession, error) { - select { - case c := <-l.chAccepts: - return c, nil - case <-l.die: - return nil, errors.New("listener stopped") - } -} - -// Close stops listening on the UDP address. Already Accepted connections are not closed. -func (l *Listener) Close() error { - if err := l.conn.Close(); err == nil { - close(l.die) - return nil - } else { - return err - } -} - -// Addr returns the listener's network address, The Addr returned is shared by all invocations of Addr, so do not modify it. -func (l *Listener) Addr() net.Addr { - return l.conn.LocalAddr() -} - -func currentMs() uint32 { - return uint32(time.Now().UnixNano() / int64(time.Millisecond)) -} diff --git a/transport/internet/kcp/session.go b/transport/internet/kcp/session.go deleted file mode 100644 index ef32dba2..00000000 --- a/transport/internet/kcp/session.go +++ /dev/null @@ -1,191 +0,0 @@ -package kcp - -import ( - "errors" - "net" - "time" - - v2net "github.com/v2ray/v2ray-core/common/net" - "github.com/v2ray/v2ray-core/transport/internet" -) - -type KCPVlistener struct { - lst *Listener - previousSocketid map[int]uint32 - previousSocketid_mapid int -} - -/*Accept Accept a KCP connection -Since KCP is stateless, if package deliver after it was closed, -It could be reconized as a new connection and call accept. -If we can detect that the connection is of such a kind, -we will discard that conn. -*/ -func (kvl *KCPVlistener) Accept() (internet.Connection, error) { - conn, err := kvl.lst.Accept() - if err != nil { - return nil, err - } - - if kvl.previousSocketid == nil { - kvl.previousSocketid = make(map[int]uint32) - } - - var badbit bool = false - - for _, key := range kvl.previousSocketid { - if key == conn.GetConv() { - badbit = true - } - } - if badbit { - conn.Close() - return nil, errors.New("KCP:ConnDup, Don't worry~") - } else { - kvl.previousSocketid_mapid++ - kvl.previousSocketid[kvl.previousSocketid_mapid] = conn.GetConv() - /* - Here we assume that count(connection) < 512 - This won't always true. - More work might be necessary to deal with this in a better way. - */ - if kvl.previousSocketid_mapid >= 512 { - delete(kvl.previousSocketid, kvl.previousSocketid_mapid-512) - } - } - - kcv := &KCPVconn{hc: conn} - err = kcv.ApplyConf() - if err != nil { - return nil, err - } - return kcv, nil -} - -func (kvl *KCPVlistener) Close() error { - return kvl.lst.Close() -} - -func (kvl *KCPVlistener) Addr() net.Addr { - return kvl.lst.Addr() -} - -type KCPVconn struct { - hc *UDPSession - conntokeep time.Time -} - -func (kcpvc *KCPVconn) Read(b []byte) (int, error) { - ifb := time.Now().Add(time.Duration(effectiveConfig.ReadTimeout) * time.Second) - if ifb.After(kcpvc.conntokeep) { - kcpvc.conntokeep = ifb - } - kcpvc.hc.SetDeadline(kcpvc.conntokeep) - return kcpvc.hc.Read(b) -} - -func (kcpvc *KCPVconn) Write(b []byte) (int, error) { - ifb := time.Now().Add(time.Duration(effectiveConfig.WriteTimeout) * time.Second) - if ifb.After(kcpvc.conntokeep) { - kcpvc.conntokeep = ifb - } - kcpvc.hc.SetDeadline(kcpvc.conntokeep) - return kcpvc.hc.Write(b) -} - -/*ApplyConf will apply kcpvc.conf to current Socket - -It is recommmanded to call this func once and only once -*/ -func (kcpvc *KCPVconn) ApplyConf() error { - nodelay, interval, resend, nc := 0, 40, 0, 0 - switch effectiveConfig.Mode { - case "normal": - nodelay, interval, resend, nc = 0, 30, 2, 1 - case "fast": - nodelay, interval, resend, nc = 0, 20, 2, 1 - case "fast2": - nodelay, interval, resend, nc = 1, 20, 2, 1 - case "fast3": - nodelay, interval, resend, nc = 1, 10, 2, 1 - } - - kcpvc.hc.SetNoDelay(nodelay, interval, resend, nc) - kcpvc.hc.SetWindowSize(effectiveConfig.Sndwnd, effectiveConfig.Rcvwnd) - kcpvc.hc.SetMtu(effectiveConfig.Mtu) - kcpvc.hc.SetACKNoDelay(effectiveConfig.Acknodelay) - kcpvc.hc.SetDSCP(effectiveConfig.Dscp) - return nil -} - -/*Close Close the current conn -We have to delay the close of Socket for a few second -or the VMess EOF can be too late to send. -*/ -func (kcpvc *KCPVconn) Close() error { - go func() { - time.Sleep(2000 * time.Millisecond) - kcpvc.hc.Close() - }() - return nil -} - -func (kcpvc *KCPVconn) LocalAddr() net.Addr { - return kcpvc.hc.LocalAddr() -} - -func (kcpvc *KCPVconn) RemoteAddr() net.Addr { - return kcpvc.hc.RemoteAddr() -} - -func (kcpvc *KCPVconn) SetDeadline(t time.Time) error { - return kcpvc.hc.SetDeadline(t) -} - -func (kcpvc *KCPVconn) SetReadDeadline(t time.Time) error { - return kcpvc.hc.SetReadDeadline(t) -} - -func (kcpvc *KCPVconn) SetWriteDeadline(t time.Time) error { - return kcpvc.hc.SetWriteDeadline(t) -} - -func (this *KCPVconn) Reusable() bool { - return false -} - -func (this *KCPVconn) SetReusable(b bool) { - -} - -func ListenKCP(address v2net.Address, port v2net.Port) (internet.Listener, error) { - conn, err := net.ListenUDP("udp", &net.UDPAddr{ - IP: address.IP(), - Port: int(port), - }) - if err != nil { - return nil, err - } - - block, _ := NewNoneBlockCrypt(nil) - - l := new(Listener) - l.conn = conn - l.sessions = make(map[string]*UDPSession) - l.chAccepts = make(chan *UDPSession, 1024) - l.chDeadlinks = make(chan net.Addr, 1024) - l.die = make(chan struct{}) - l.block = block - - // caculate header size - if l.block != nil { - l.headerSize += cryptHeaderSize - } - - go l.monitor() - return &KCPVlistener{lst: l}, nil -} - -func init() { - internet.KCPListenFunc = ListenKCP -}