diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 0b15a3a0..ad130fb5 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -23,19 +23,6 @@ const ( headerSize uint32 = 2 ) -type Command byte - -var ( - CommandData Command = 0 - CommandTerminate Command = 1 -) - -type Option byte - -var ( - OptionClose Option = 1 -) - type ConnState byte var ( @@ -69,7 +56,7 @@ type Connection struct { } // NewConnection create a new KCP connection between local and remote. -func NewConnection(conv uint32, writerCloser io.WriteCloser, local *net.UDPAddr, remote *net.UDPAddr, block Authenticator) *Connection { +func NewConnection(conv uint16, writerCloser io.WriteCloser, local *net.UDPAddr, remote *net.UDPAddr, block Authenticator) *Connection { conn := new(Connection) conn.local = local conn.chReadEvent = make(chan struct{}, 1) @@ -79,8 +66,12 @@ func NewConnection(conv uint32, writerCloser io.WriteCloser, local *net.UDPAddr, conn.since = nowMillisec() conn.writeBufferSize = effectiveConfig.WriteBuffer / effectiveConfig.Mtu + authWriter := &AuthenticationWriter{ + Authenticator: block, + Writer: writerCloser, + } mtu := effectiveConfig.Mtu - uint32(block.HeaderSize()) - headerSize - conn.kcp = NewKCP(conv, mtu, effectiveConfig.GetSendingWindowSize(), effectiveConfig.GetReceivingWindowSize(), conn.writeBufferSize, conn.output) + conn.kcp = NewKCP(conv, mtu, effectiveConfig.GetSendingWindowSize(), effectiveConfig.GetReceivingWindowSize(), conn.writeBufferSize, authWriter) conn.kcp.NoDelay(effectiveConfig.Tti, 2, effectiveConfig.Congestion) conn.kcp.current = conn.Elapsed() @@ -95,13 +86,19 @@ func (this *Connection) Elapsed() uint32 { // Read implements the Conn Read method. func (this *Connection) Read(b []byte) (int, error) { - if this == nil || this.state == ConnStateReadyToClose || this.state == ConnStateClosed { + if this == nil || + this.kcp.state == StateReadyToClose || + this.kcp.state == StateTerminating || + this.kcp.state == StateTerminated { return 0, io.EOF } for { this.RLock() - if this.state == ConnStateReadyToClose || this.state == ConnStateClosed { + if this == nil || + this.kcp.state == StateReadyToClose || + this.kcp.state == StateTerminating || + this.kcp.state == StateTerminated { this.RUnlock() return 0, io.EOF } @@ -127,19 +124,14 @@ func (this *Connection) Read(b []byte) (int, error) { // Write implements the Conn Write method. func (this *Connection) Write(b []byte) (int, error) { - if this == nil || - this.state == ConnStateReadyToClose || - this.state == ConnStatePeerClosed || - this.state == ConnStateClosed { + if this == nil || this.kcp.state != StateActive { return 0, io.ErrClosedPipe } totalWritten := 0 for { this.RLock() - if this.state == ConnStateReadyToClose || - this.state == ConnStatePeerClosed || - this.state == ConnStateClosed { + if this == nil || this.kcp.state != StateActive { this.RUnlock() return totalWritten, io.ErrClosedPipe } @@ -166,72 +158,21 @@ func (this *Connection) Write(b []byte) (int, error) { } } -func (this *Connection) Terminate() { - if this == nil || this.state == ConnStateClosed { - return - } - this.Lock() - defer this.Unlock() - if this.state == ConnStateClosed { - return - } - - this.state = ConnStateClosed - this.writer.Close() -} - -func (this *Connection) NotifyTermination() { - for i := 0; i < 16; i++ { - this.RLock() - if this.state == ConnStateClosed { - this.RUnlock() - break - } - this.RUnlock() - buffer := alloc.NewSmallBuffer().Clear() - buffer.AppendBytes(byte(CommandTerminate), byte(OptionClose), byte(0), byte(0), byte(0), byte(0)) - this.outputBuffer(buffer) - - time.Sleep(time.Second) - - } - this.Terminate() -} - -func (this *Connection) ForceTimeout() { - if this == nil { - return - } - for i := 0; i < 5; i++ { - if this.state == ConnStateClosed { - return - } - time.Sleep(time.Minute) - } - go this.terminateOnce.Do(this.NotifyTermination) -} - // Close closes the connection. func (this *Connection) Close() error { - if this == nil || this.state == ConnStateClosed || this.state == ConnStateReadyToClose { + if this == nil || + this.kcp.state == StateReadyToClose || + this.kcp.state == StateTerminating || + this.kcp.state == StateTerminated { return errClosedConnection } log.Debug("KCP|Connection: Closing connection to ", this.remote) this.Lock() defer this.Unlock() - if this.state == ConnStateActive { - this.state = ConnStateReadyToClose - if this.kcp.WaitSnd() == 0 { - go this.terminateOnce.Do(this.NotifyTermination) - } else { - go this.ForceTimeout() - } - } - - if this.state == ConnStatePeerClosed { - go this.Terminate() - } + this.kcpAccess.Lock() + this.kcp.OnClose() + this.kcpAccess.Unlock() return nil } @@ -254,7 +195,7 @@ func (this *Connection) RemoteAddr() net.Addr { // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. func (this *Connection) SetDeadline(t time.Time) error { - if this == nil || this.state != ConnStateActive { + if this == nil || this.kcp.state != StateActive { return errClosedConnection } this.Lock() @@ -266,7 +207,7 @@ func (this *Connection) SetDeadline(t time.Time) error { // SetReadDeadline implements the Conn SetReadDeadline method. func (this *Connection) SetReadDeadline(t time.Time) error { - if this == nil || this.state != ConnStateActive { + if this == nil || this.kcp.state != StateActive { return errClosedConnection } this.Lock() @@ -277,7 +218,7 @@ func (this *Connection) SetReadDeadline(t time.Time) error { // SetWriteDeadline implements the Conn SetWriteDeadline method. func (this *Connection) SetWriteDeadline(t time.Time) error { - if this == nil || this.state != ConnStateActive { + if this == nil || this.kcp.state != StateActive { return errClosedConnection } this.Lock() @@ -286,54 +227,21 @@ func (this *Connection) SetWriteDeadline(t time.Time) error { return nil } -func (this *Connection) outputBuffer(payload *alloc.Buffer) { - defer payload.Release() - if this == nil { - return - } - - this.RLock() - defer this.RUnlock() - if this.state == ConnStatePeerClosed || this.state == ConnStateClosed { - return - } - this.block.Seal(payload) - - this.writer.Write(payload.Value) -} - -func (this *Connection) output(payload []byte) { - if this == nil || this.state == ConnStateClosed { - return - } - - if this.state == ConnStateReadyToClose && this.kcp.WaitSnd() == 0 { - go this.terminateOnce.Do(this.NotifyTermination) - } - - if len(payload) < IKCP_OVERHEAD { - return - } - - buffer := alloc.NewBuffer().Clear().Append(payload) - cmd := CommandData - opt := Option(0) - if this.state == ConnStateReadyToClose { - opt = OptionClose - } - buffer.Prepend([]byte{byte(cmd), byte(opt)}) - this.outputBuffer(buffer) -} - // kcp update, input loop func (this *Connection) updateTask() { - for this.state != ConnStateClosed { + for this.kcp.state != StateTerminated { current := this.Elapsed() this.kcpAccess.Lock() this.kcp.Update(current) this.kcpAccess.Unlock() - time.Sleep(time.Duration(effectiveConfig.Tti) * time.Millisecond) + + interval := time.Duration(effectiveConfig.Tti) * time.Millisecond + if this.kcp.state == StateTerminating { + interval = time.Second + } + time.Sleep(interval) } + this.Terminate() } func (this *Connection) notifyReadEvent() { @@ -343,35 +251,10 @@ func (this *Connection) notifyReadEvent() { } } -func (this *Connection) MarkPeerClose() { - this.Lock() - defer this.Unlock() - if this.state == ConnStateReadyToClose { - this.state = ConnStateClosed - go this.Terminate() - return - } - if this.state == ConnStateActive { - this.state = ConnStatePeerClosed - } - this.kcpAccess.Lock() - this.kcp.ClearSendQueue() - this.kcpAccess.Unlock() -} - func (this *Connection) kcpInput(data []byte) { - cmd := Command(data[0]) - opt := Option(data[1]) - if cmd == CommandTerminate { - go this.Terminate() - return - } - if opt == OptionClose { - go this.MarkPeerClose() - } this.kcpAccess.Lock() this.kcp.current = this.Elapsed() - this.kcp.Input(data[2:]) + this.kcp.Input(data) this.kcpAccess.Unlock() this.notifyReadEvent() @@ -383,6 +266,7 @@ func (this *Connection) FetchInputFrom(conn net.Conn) { payload := alloc.NewBuffer() nBytes, err := conn.Read(payload.Value) if err != nil { + payload.Release() return } payload.Slice(0, nBytes) @@ -401,3 +285,12 @@ func (this *Connection) Reusable() bool { } func (this *Connection) SetReusable(b bool) {} + +func (this *Connection) Terminate() { + if this == nil || this.writer == nil { + return + } + log.Info("Terminating connection to ", this.RemoteAddr()) + + this.writer.Close() +} diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index 3eabfcce..a3329469 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -18,7 +18,7 @@ func DialKCP(src v2net.Address, dest v2net.Destination) (internet.Connection, er } cpip := NewSimpleAuthenticator() - session := NewConnection(rand.Uint32(), conn, conn.LocalAddr().(*net.UDPAddr), conn.RemoteAddr().(*net.UDPAddr), cpip) + session := NewConnection(uint16(rand.Uint32()), conn, conn.LocalAddr().(*net.UDPAddr), conn.RemoteAddr().(*net.UDPAddr), cpip) session.FetchInputFrom(conn) return session, nil diff --git a/transport/internet/kcp/kcp.go b/transport/internet/kcp/kcp.go index b74674de..ef5ebf12 100644 --- a/transport/internet/kcp/kcp.go +++ b/transport/internet/kcp/kcp.go @@ -6,9 +6,9 @@ package kcp import ( - "encoding/binary" - "github.com/v2ray/v2ray-core/common/alloc" + v2io "github.com/v2ray/v2ray-core/common/io" + "github.com/v2ray/v2ray-core/common/log" ) const ( @@ -31,45 +31,6 @@ const ( IKCP_PROBE_LIMIT = 120000 // up to 120 secs to probe window ) -// Output is a closure which captures conn and calls conn.Write -type Output func(buf []byte) - -/* encode 8 bits unsigned int */ -func ikcp_encode8u(p []byte, c byte) []byte { - p[0] = c - return p[1:] -} - -/* decode 8 bits unsigned int */ -func ikcp_decode8u(p []byte, c *byte) []byte { - *c = p[0] - return p[1:] -} - -/* encode 16 bits unsigned int (lsb) */ -func ikcp_encode16u(p []byte, w uint16) []byte { - binary.LittleEndian.PutUint16(p, w) - return p[2:] -} - -/* decode 16 bits unsigned int (lsb) */ -func ikcp_decode16u(p []byte, w *uint16) []byte { - *w = binary.LittleEndian.Uint16(p) - return p[2:] -} - -/* encode 32 bits unsigned int (lsb) */ -func ikcp_encode32u(p []byte, l uint32) []byte { - binary.LittleEndian.PutUint32(p, l) - return p[4:] -} - -/* decode 32 bits unsigned int (lsb) */ -func ikcp_decode32u(p []byte, l *uint32) []byte { - *l = binary.LittleEndian.Uint32(p) - return p[4:] -} - func _imin_(a, b uint32) uint32 { if a <= b { return a @@ -90,49 +51,22 @@ func _itimediff(later, earlier uint32) int32 { return (int32)(later - earlier) } -// Segment defines a KCP segment -type Segment struct { - conv uint32 - cmd uint32 - frg uint32 - wnd uint32 - ts uint32 - sn uint32 - una uint32 - resendts uint32 - fastack uint32 - xmit uint32 - data *alloc.Buffer -} +type State int -// encode a segment into buffer -func (seg *Segment) encode(ptr []byte) []byte { - ptr = ikcp_encode32u(ptr, seg.conv) - ptr = ikcp_encode8u(ptr, uint8(seg.cmd)) - ptr = ikcp_encode8u(ptr, uint8(seg.frg)) - ptr = ikcp_encode16u(ptr, uint16(seg.wnd)) - ptr = ikcp_encode32u(ptr, seg.ts) - ptr = ikcp_encode32u(ptr, seg.sn) - ptr = ikcp_encode32u(ptr, seg.una) - ptr = ikcp_encode16u(ptr, uint16(seg.data.Len())) - return ptr -} - -func (this *Segment) Release() { - this.data.Release() - this.data = nil -} - -// NewSegment creates a KCP segment -func NewSegment() *Segment { - return &Segment{ - data: alloc.NewSmallBuffer().Clear(), - } -} +const ( + StateActive State = 0 + StateReadyToClose State = 1 + StatePeerClosed State = 2 + StateTerminating State = 3 + StateTerminated State = 4 +) // KCP defines a single KCP connection type KCP struct { - conv, mtu, mss, state uint32 + conv uint16 + state State + stateBeginTime uint32 + mtu, mss uint32 snd_una, snd_nxt, rcv_nxt uint32 ts_recent, ts_lastack, ssthresh uint32 rx_rttvar, rx_srtt, rx_rto uint32 @@ -143,21 +77,21 @@ type KCP struct { dead_link, incr uint32 snd_queue *SendingQueue - rcv_queue []*Segment - snd_buf []*Segment + rcv_queue []*DataSegment + snd_buf []*DataSegment rcv_buf *ReceivingWindow - acklist []uint32 + acklist *ACKList buffer []byte fastresend int32 congestionControl bool - output Output + output *SegmentWriter } // NewKCP create a new kcp control object, 'conv' must equal in two endpoint // from the same connection. -func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSize uint32, sendingQueueSize uint32, output Output) *KCP { +func NewKCP(conv uint16, mtu uint32, sendingWindowSize uint32, receivingWindowSize uint32, sendingQueueSize uint32, output v2io.Writer) *KCP { kcp := new(KCP) kcp.conv = conv kcp.snd_wnd = sendingWindowSize @@ -165,18 +99,51 @@ func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSi kcp.rmt_wnd = IKCP_WND_RCV kcp.mtu = mtu kcp.mss = kcp.mtu - IKCP_OVERHEAD - kcp.buffer = make([]byte, (kcp.mtu+IKCP_OVERHEAD)*3) kcp.rx_rto = IKCP_RTO_DEF kcp.interval = IKCP_INTERVAL kcp.ts_flush = IKCP_INTERVAL kcp.ssthresh = IKCP_THRESH_INIT kcp.dead_link = IKCP_DEADLINK - kcp.output = output + kcp.output = NewSegmentWriter(mtu, output) kcp.rcv_buf = NewReceivingWindow(receivingWindowSize) kcp.snd_queue = NewSendingQueue(sendingQueueSize) + kcp.acklist = new(ACKList) return kcp } +func (kcp *KCP) HandleOption(opt SegmentOption) { + if (opt & SegmentOptionClose) == SegmentOptionClose { + kcp.OnPeerClosed() + } +} + +func (kcp *KCP) OnPeerClosed() { + if kcp.state == StateReadyToClose { + kcp.state = StateTerminating + kcp.stateBeginTime = kcp.current + log.Info("KCP terminating at ", kcp.current) + } + if kcp.state == StateActive { + kcp.ClearSendQueue() + kcp.state = StatePeerClosed + kcp.stateBeginTime = kcp.current + log.Info("KCP peer close at ", kcp.current) + } +} + +func (kcp *KCP) OnClose() { + if kcp.state == StateActive { + kcp.state = StateReadyToClose + kcp.stateBeginTime = kcp.current + log.Info("KCP ready close at ", kcp.current) + } + if kcp.state == StatePeerClosed { + kcp.state = StateTerminating + kcp.stateBeginTime = kcp.current + log.Info("KCP terminating at ", kcp.current) + } +} + // Recv is user/upper level recv: returns size, returns below zero for EAGAIN func (kcp *KCP) Recv(buffer []byte) (n int) { if len(kcp.rcv_queue) == 0 { @@ -186,11 +153,11 @@ func (kcp *KCP) Recv(buffer []byte) (n int) { // merge fragment count := 0 for _, seg := range kcp.rcv_queue { - dataLen := seg.data.Len() + dataLen := seg.Data.Len() if dataLen > len(buffer) { break } - copy(buffer, seg.data.Value) + copy(buffer, seg.Data.Value) seg.Release() buffer = buffer[dataLen:] n += dataLen @@ -226,8 +193,9 @@ func (kcp *KCP) Send(buffer []byte) int { } else { size = len(buffer) } - seg := NewSegment() - seg.data.Append(buffer[:size]) + seg := &DataSegment{ + Data: alloc.NewSmallBuffer().Clear().Append(buffer[:size]), + } kcp.snd_queue.Push(seg) buffer = buffer[size:] nBytes += size @@ -262,7 +230,7 @@ func (kcp *KCP) update_ack(rtt int32) { func (kcp *KCP) shrink_buf() { if len(kcp.snd_buf) > 0 { seg := kcp.snd_buf[0] - kcp.snd_una = seg.sn + kcp.snd_una = seg.Number } else { kcp.snd_una = kcp.snd_nxt } @@ -274,12 +242,12 @@ func (kcp *KCP) parse_ack(sn uint32) { } for k, seg := range kcp.snd_buf { - if sn == seg.sn { + if sn == seg.Number { kcp.snd_buf = append(kcp.snd_buf[:k], kcp.snd_buf[k+1:]...) seg.Release() break } - if _itimediff(sn, seg.sn) < 0 { + if _itimediff(sn, seg.Number) < 0 { break } } @@ -291,18 +259,18 @@ func (kcp *KCP) parse_fastack(sn uint32) { } for _, seg := range kcp.snd_buf { - if _itimediff(sn, seg.sn) < 0 { + if _itimediff(sn, seg.Number) < 0 { break - } else if sn != seg.sn { - seg.fastack++ + } else if sn != seg.Number { + seg.ackSkipped++ } } } -func (kcp *KCP) parse_una(una uint32) { +func (kcp *KCP) HandleReceivingNext(receivingNext uint32) { count := 0 for _, seg := range kcp.snd_buf { - if _itimediff(una, seg.sn) > 0 { + if _itimediff(receivingNext, seg.Number) > 0 { seg.Release() count++ } else { @@ -312,17 +280,12 @@ func (kcp *KCP) parse_una(una uint32) { kcp.snd_buf = kcp.snd_buf[count:] } -// ack append -func (kcp *KCP) ack_push(sn, ts uint32) { - kcp.acklist = append(kcp.acklist, sn, ts) +func (kcp *KCP) HandleSendingNext(sendingNext uint32) { + kcp.acklist.Clear(sendingNext) } -func (kcp *KCP) ack_get(p int) (sn, ts uint32) { - return kcp.acklist[p*2+0], kcp.acklist[p*2+1] -} - -func (kcp *KCP) parse_data(newseg *Segment) { - sn := newseg.sn +func (kcp *KCP) parse_data(newseg *DataSegment) { + sn := newseg.Number if _itimediff(sn, kcp.rcv_nxt+kcp.rcv_wnd) >= 0 || _itimediff(sn, kcp.rcv_nxt) < 0 { return @@ -338,163 +301,132 @@ func (kcp *KCP) parse_data(newseg *Segment) { // Input when you received a low level packet (eg. UDP packet), call it func (kcp *KCP) Input(data []byte) int { - //una := kcp.snd_una - if len(data) < IKCP_OVERHEAD { - return -1 - } - + log.Info("KCP input at ", kcp.current) + var seg ISegment var maxack uint32 var flag int for { - var ts, sn, una, conv uint32 - var wnd, length uint16 - var cmd, frg uint8 - - if len(data) < int(IKCP_OVERHEAD) { + seg, data = ReadSegment(data) + if seg == nil { break } - data = ikcp_decode32u(data, &conv) - if conv != kcp.conv { - return -1 - } - - data = ikcp_decode8u(data, &cmd) - data = ikcp_decode8u(data, &frg) - data = ikcp_decode16u(data, &wnd) - data = ikcp_decode32u(data, &ts) - data = ikcp_decode32u(data, &sn) - data = ikcp_decode32u(data, &una) - data = ikcp_decode16u(data, &length) - if len(data) < int(length) { - return -2 - } - - if cmd != IKCP_CMD_PUSH && cmd != IKCP_CMD_ACK { - return -3 - } - - if kcp.rmt_wnd < uint32(wnd) { - kcp.rmt_wnd = uint32(wnd) - } - - kcp.parse_una(una) - kcp.shrink_buf() - - if cmd == IKCP_CMD_ACK { - if _itimediff(kcp.current, ts) >= 0 { - kcp.update_ack(_itimediff(kcp.current, ts)) - } - kcp.parse_ack(sn) + switch seg := seg.(type) { + case *DataSegment: + kcp.HandleOption(seg.Opt) + kcp.HandleSendingNext(seg.SendingNext) kcp.shrink_buf() - if flag == 0 { - flag = 1 - maxack = sn - } else if _itimediff(sn, maxack) > 0 { - maxack = sn + kcp.acklist.Add(seg.Number, seg.Timestamp) + kcp.parse_data(seg) + case *ACKSegment: + kcp.HandleOption(seg.Opt) + if kcp.rmt_wnd < seg.ReceivingWindow { + kcp.rmt_wnd = seg.ReceivingWindow } - } else if cmd == IKCP_CMD_PUSH { - if _itimediff(sn, kcp.rcv_nxt+kcp.rcv_wnd) < 0 { - kcp.ack_push(sn, ts) - if _itimediff(sn, kcp.rcv_nxt) >= 0 { - seg := NewSegment() - seg.conv = conv - seg.cmd = uint32(cmd) - seg.frg = uint32(frg) - seg.wnd = uint32(wnd) - seg.ts = ts - seg.sn = sn - seg.una = una - seg.data.Append(data[:length]) - kcp.parse_data(seg) + kcp.HandleReceivingNext(seg.ReceivingNext) + for i := 0; i < int(seg.Count); i++ { + ts := seg.TimestampList[i] + sn := seg.NumberList[i] + if _itimediff(kcp.current, ts) >= 0 { + kcp.update_ack(_itimediff(kcp.current, ts)) + } + kcp.parse_ack(sn) + if flag == 0 { + flag = 1 + maxack = sn + } else if _itimediff(sn, maxack) > 0 { + maxack = sn } } - } else { - return -3 + kcp.shrink_buf() + case *CmdOnlySegment: + kcp.HandleOption(seg.Opt) + if seg.Cmd == SegmentCommandTerminated { + if kcp.state == StateActive || + kcp.state == StateReadyToClose || + kcp.state == StatePeerClosed { + kcp.state = StateTerminating + kcp.stateBeginTime = kcp.current + log.Info("KCP terminating at ", kcp.current) + } else if kcp.state == StateTerminating { + kcp.state = StateTerminated + kcp.stateBeginTime = kcp.current + log.Info("KCP terminated at ", kcp.current) + } + } + kcp.HandleReceivingNext(seg.ReceivinNext) + kcp.HandleSendingNext(seg.SendingNext) + default: } - - data = data[length:] } if flag != 0 { kcp.parse_fastack(maxack) } - /* - if _itimediff(kcp.snd_una, una) > 0 { - if kcp.cwnd < kcp.rmt_wnd { - mss := kcp.mss - if kcp.cwnd < kcp.ssthresh { - kcp.cwnd++ - kcp.incr += mss - } else { - if kcp.incr < mss { - kcp.incr = mss - } - kcp.incr += (mss*mss)/kcp.incr + (mss / 16) - if (kcp.cwnd+1)*mss <= kcp.incr { - kcp.cwnd++ - } - } - if kcp.cwnd > kcp.rmt_wnd { - kcp.cwnd = kcp.rmt_wnd - kcp.incr = kcp.rmt_wnd * mss - } - } - }*/ - return 0 } // flush pending data func (kcp *KCP) flush() { - current := kcp.current - buffer := kcp.buffer - change := 0 - //lost := false - - if !kcp.updated { + if kcp.state == StateTerminated { return } - var seg Segment - seg.conv = kcp.conv - seg.cmd = IKCP_CMD_ACK - seg.wnd = uint32(kcp.rcv_nxt + kcp.rcv_wnd) - seg.una = kcp.rcv_nxt + if kcp.state == StateTerminating { + kcp.output.Write(&CmdOnlySegment{ + Conv: kcp.conv, + Cmd: SegmentCommandTerminated, + }) + kcp.output.Flush() + + if _itimediff(kcp.current, kcp.stateBeginTime) > 8000 { + kcp.state = StateTerminated + log.Info("KCP terminated at ", kcp.current) + kcp.stateBeginTime = kcp.current + } + return + } + + if kcp.state == StateReadyToClose && _itimediff(kcp.current, kcp.stateBeginTime) > 15000 { + kcp.state = StateTerminating + log.Info("KCP terminating at ", kcp.current) + kcp.stateBeginTime = kcp.current + } + + current := kcp.current + segSent := false + //lost := false + + //var seg Segment + //seg.conv = kcp.conv + //seg.cmd = IKCP_CMD_ACK + //seg.wnd = uint32(kcp.rcv_nxt + kcp.rcv_wnd) + //seg.una = kcp.rcv_nxt // flush acknowledges - count := len(kcp.acklist) / 2 - ptr := buffer - for i := 0; i < count; i++ { - size := len(buffer) - len(ptr) - if size+IKCP_OVERHEAD > int(kcp.mtu) { - kcp.output(buffer[:size]) - ptr = buffer - } - seg.sn, seg.ts = kcp.ack_get(i) - ptr = seg.encode(ptr) + ackSeg := kcp.acklist.AsSegment() + if ackSeg != nil { + ackSeg.Conv = kcp.conv + ackSeg.ReceivingWindow = uint32(kcp.rcv_nxt + kcp.rcv_wnd) + ackSeg.ReceivingNext = kcp.rcv_nxt + kcp.output.Write(ackSeg) + segSent = true } - kcp.acklist = nil // calculate window size - cwnd := _imin_(kcp.snd_una+kcp.snd_wnd, kcp.rmt_wnd) if kcp.congestionControl { cwnd = _imin_(kcp.cwnd, cwnd) } for !kcp.snd_queue.IsEmpty() && _itimediff(kcp.snd_nxt, cwnd) < 0 { - newseg := kcp.snd_queue.Pop() - newseg.conv = kcp.conv - newseg.cmd = IKCP_CMD_PUSH - newseg.wnd = seg.wnd - newseg.ts = current - newseg.sn = kcp.snd_nxt - newseg.una = kcp.rcv_nxt - newseg.resendts = current - newseg.fastack = 0 - newseg.xmit = 0 - kcp.snd_buf = append(kcp.snd_buf, newseg) + seg := kcp.snd_queue.Pop() + seg.Conv = kcp.conv + seg.Number = kcp.snd_nxt + seg.timeout = current + seg.ackSkipped = 0 + seg.transmit = 0 + kcp.snd_buf = append(kcp.snd_buf, seg) kcp.snd_nxt++ } @@ -507,51 +439,75 @@ func (kcp *KCP) flush() { // flush data segments for _, segment := range kcp.snd_buf { needsend := false - if segment.xmit == 0 { + if segment.transmit == 0 { needsend = true - segment.xmit++ - segment.resendts = current + kcp.rx_rto - } else if _itimediff(current, segment.resendts) >= 0 { + segment.transmit++ + segment.timeout = current + kcp.rx_rto + } else if _itimediff(current, segment.timeout) >= 0 { needsend = true - segment.xmit++ + segment.transmit++ kcp.xmit++ - segment.resendts = current + kcp.rx_rto + segment.timeout = current + kcp.rx_rto //lost = true - } else if segment.fastack >= resent { + } else if segment.ackSkipped >= resent { needsend = true - segment.xmit++ - segment.fastack = 0 - segment.resendts = current + kcp.rx_rto - change++ + segment.transmit++ + segment.ackSkipped = 0 + segment.timeout = current + kcp.rx_rto } if needsend { - segment.ts = current - segment.wnd = seg.wnd - segment.una = kcp.rcv_nxt - - size := len(buffer) - len(ptr) - need := IKCP_OVERHEAD + segment.data.Len() - - if size+need >= int(kcp.mtu) { - kcp.output(buffer[:size]) - ptr = buffer + segment.Timestamp = current + segment.SendingNext = kcp.snd_una + segment.Opt = 0 + if kcp.state == StateReadyToClose { + segment.Opt = SegmentOptionClose } - ptr = segment.encode(ptr) - copy(ptr, segment.data.Value) - ptr = ptr[segment.data.Len():] + kcp.output.Write(segment) + segSent = true - if segment.xmit >= kcp.dead_link { + if segment.transmit >= kcp.dead_link { kcp.state = 0xFFFFFFFF } } } // flash remain segments - size := len(buffer) - len(ptr) - if size > 0 { - kcp.output(buffer[:size]) + kcp.output.Flush() + + if !segSent && kcp.state == StateReadyToClose { + kcp.output.Write(&CmdOnlySegment{ + Conv: kcp.conv, + Cmd: SegmentCommandPing, + Opt: SegmentOptionClose, + ReceivinNext: kcp.rcv_nxt, + SendingNext: kcp.snd_nxt, + }) + kcp.output.Flush() + segSent = true + } + + if !segSent && kcp.state == StateTerminating { + kcp.output.Write(&CmdOnlySegment{ + Conv: kcp.conv, + Cmd: SegmentCommandTerminated, + ReceivinNext: kcp.rcv_nxt, + SendingNext: kcp.snd_una, + }) + kcp.output.Flush() + segSent = true + } + + if !segSent { + kcp.output.Write(&CmdOnlySegment{ + Conv: kcp.conv, + Cmd: SegmentCommandPing, + ReceivinNext: kcp.rcv_nxt, + SendingNext: kcp.snd_una, + }) + kcp.output.Flush() + segSent = true } // update ssthresh @@ -613,54 +569,6 @@ func (kcp *KCP) Update(current uint32) { } } -// Check determines when should you invoke ikcp_update: -// returns when you should invoke ikcp_update in millisec, if there -// is no ikcp_input/_send calling. you can call ikcp_update in that -// time, instead of call update repeatly. -// Important to reduce unnacessary ikcp_update invoking. use it to -// schedule ikcp_update (eg. implementing an epoll-like mechanism, -// or optimize ikcp_update when handling massive kcp connections) -func (kcp *KCP) Check(current uint32) uint32 { - ts_flush := kcp.ts_flush - tm_flush := int32(0x7fffffff) - tm_packet := int32(0x7fffffff) - minimal := uint32(0) - if !kcp.updated { - return current - } - - if _itimediff(current, ts_flush) >= 10000 || - _itimediff(current, ts_flush) < -10000 { - ts_flush = current - } - - if _itimediff(current, ts_flush) >= 0 { - return current - } - - tm_flush = _itimediff(ts_flush, current) - - for _, seg := range kcp.snd_buf { - diff := _itimediff(seg.resendts, current) - if diff <= 0 { - return current - } - if diff < tm_packet { - tm_packet = diff - } - } - - minimal = uint32(tm_packet) - if tm_packet >= tm_flush { - minimal = uint32(tm_flush) - } - if minimal >= kcp.interval { - minimal = kcp.interval - } - - return current + minimal -} - // NoDelay options // fastest: ikcp_nodelay(kcp, 1, 20, 2, 1) // nodelay: 0:disable(default), 1:enable diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 494468cf..6259b88c 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -1,7 +1,6 @@ package kcp import ( - "encoding/binary" "net" "sync" "time" @@ -9,6 +8,7 @@ import ( "github.com/v2ray/v2ray-core/common/alloc" "github.com/v2ray/v2ray-core/common/log" v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/common/serial" "github.com/v2ray/v2ray-core/transport/internet" "github.com/v2ray/v2ray-core/transport/internet/udp" ) @@ -62,7 +62,7 @@ func (this *Listener) OnReceive(payload *alloc.Buffer, src v2net.Destination) { srcAddrStr := src.NetAddr() conn, found := this.sessions[srcAddrStr] if !found { - conv := binary.LittleEndian.Uint32(payload.Value[2:6]) + conv := serial.BytesToUint16(payload.Value) writer := &Writer{ hub: this.hub, dest: src, diff --git a/transport/internet/kcp/output.go b/transport/internet/kcp/output.go index ecff0668..8415f061 100644 --- a/transport/internet/kcp/output.go +++ b/transport/internet/kcp/output.go @@ -1,6 +1,7 @@ package kcp import ( + "io" "sync" "github.com/v2ray/v2ray-core/common/alloc" @@ -34,7 +35,7 @@ func (this *SegmentWriter) Write(seg ISegment) { this.buffer = alloc.NewSmallBuffer().Clear() } - this.buffer.Value = seg.Bytes(this.buffer.Value) + this.buffer.Append(seg.Bytes(nil)) } func (this *SegmentWriter) FlushWithoutLock() { @@ -52,3 +53,18 @@ func (this *SegmentWriter) Flush() { this.FlushWithoutLock() } + +type AuthenticationWriter struct { + Authenticator Authenticator + Writer io.Writer +} + +func (this *AuthenticationWriter) Write(payload *alloc.Buffer) error { + defer payload.Release() + + this.Authenticator.Seal(payload) + _, err := this.Writer.Write(payload.Value) + return err +} + +func (this *AuthenticationWriter) Release() {} diff --git a/transport/internet/kcp/receiving.go b/transport/internet/kcp/receiving.go index c6a8b6b3..2626846e 100644 --- a/transport/internet/kcp/receiving.go +++ b/transport/internet/kcp/receiving.go @@ -3,14 +3,14 @@ package kcp type ReceivingWindow struct { start uint32 size uint32 - list []*Segment + list []*DataSegment } func NewReceivingWindow(size uint32) *ReceivingWindow { return &ReceivingWindow{ start: 0, size: size, - list: make([]*Segment, size), + list: make([]*DataSegment, size), } } @@ -22,7 +22,7 @@ func (this *ReceivingWindow) Position(idx uint32) uint32 { return (idx + this.start) % this.size } -func (this *ReceivingWindow) Set(idx uint32, value *Segment) bool { +func (this *ReceivingWindow) Set(idx uint32, value *DataSegment) bool { pos := this.Position(idx) if this.list[pos] != nil { return false @@ -31,14 +31,14 @@ func (this *ReceivingWindow) Set(idx uint32, value *Segment) bool { return true } -func (this *ReceivingWindow) Remove(idx uint32) *Segment { +func (this *ReceivingWindow) Remove(idx uint32) *DataSegment { pos := this.Position(idx) e := this.list[pos] this.list[pos] = nil return e } -func (this *ReceivingWindow) RemoveFirst() *Segment { +func (this *ReceivingWindow) RemoveFirst() *DataSegment { return this.Remove(0) } @@ -76,12 +76,19 @@ func (this *ACKList) Clear(una uint32) { func (this *ACKList) AsSegment() *ACKSegment { count := len(this.numbers) - if count > 16 { - count = 16 + if count == 0 { + return nil } - return &ACKSegment{ + + if count > 128 { + count = 128 + } + seg := &ACKSegment{ Count: byte(count), NumberList: this.numbers[:count], TimestampList: this.timestamps[:count], } + //this.numbers = nil + //this.timestamps = nil + return seg } diff --git a/transport/internet/kcp/receiving_test.go b/transport/internet/kcp/receiving_test.go index 3945b859..24d03657 100644 --- a/transport/internet/kcp/receiving_test.go +++ b/transport/internet/kcp/receiving_test.go @@ -12,10 +12,10 @@ func TestRecivingWindow(t *testing.T) { window := NewReceivingWindow(3) - seg0 := &Segment{} - seg1 := &Segment{} - seg2 := &Segment{} - seg3 := &Segment{} + seg0 := &DataSegment{} + seg1 := &DataSegment{} + seg2 := &DataSegment{} + seg3 := &DataSegment{} assert.Bool(window.Set(0, seg0)).IsTrue() assert.Pointer(window.RemoveFirst()).Equals(seg0) diff --git a/transport/internet/kcp/segment.go b/transport/internet/kcp/segment.go index 285aa727..9d4fba02 100644 --- a/transport/internet/kcp/segment.go +++ b/transport/internet/kcp/segment.go @@ -3,6 +3,7 @@ package kcp import ( "github.com/v2ray/v2ray-core/common" "github.com/v2ray/v2ray-core/common/alloc" + _ "github.com/v2ray/v2ray-core/common/log" "github.com/v2ray/v2ray-core/common/serial" ) @@ -12,6 +13,7 @@ const ( SegmentCommandACK SegmentCommand = 0 SegmentCommandData SegmentCommand = 1 SegmentCommandTerminated SegmentCommand = 2 + SegmentCommandPing SegmentCommand = 3 ) type SegmentOption byte @@ -27,13 +29,12 @@ type ISegment interface { } type DataSegment struct { - Conv uint16 - Opt SegmentOption - ReceivingWindow uint32 - Timestamp uint32 - Number uint32 - Unacknowledged uint32 - Data *alloc.Buffer + Conv uint16 + Opt SegmentOption + Timestamp uint32 + Number uint32 + SendingNext uint32 + Data *alloc.Buffer timeout uint32 ackSkipped uint32 @@ -43,17 +44,16 @@ type DataSegment struct { func (this *DataSegment) Bytes(b []byte) []byte { b = serial.Uint16ToBytes(this.Conv, b) b = append(b, byte(SegmentCommandData), byte(this.Opt)) - b = serial.Uint32ToBytes(this.ReceivingWindow, b) b = serial.Uint32ToBytes(this.Timestamp, b) b = serial.Uint32ToBytes(this.Number, b) - b = serial.Uint32ToBytes(this.Unacknowledged, b) + b = serial.Uint32ToBytes(this.SendingNext, b) b = serial.Uint16ToBytes(uint16(this.Data.Len()), b) b = append(b, this.Data.Value...) return b } func (this *DataSegment) ByteSize() int { - return 2 + 1 + 1 + 4 + 4 + 4 + 4 + 2 + this.Data.Len() + return 2 + 1 + 1 + 4 + 4 + 4 + 2 + this.Data.Len() } func (this *DataSegment) Release() { @@ -64,21 +64,21 @@ type ACKSegment struct { Conv uint16 Opt SegmentOption ReceivingWindow uint32 - Unacknowledged uint32 + ReceivingNext uint32 Count byte NumberList []uint32 TimestampList []uint32 } func (this *ACKSegment) ByteSize() int { - return 2 + 1 + 1 + 4 + 4 + 1 + len(this.NumberList)*4 + len(this.TimestampList)*4 + return 2 + 1 + 1 + 4 + 4 + 1 + int(this.Count)*4 + int(this.Count)*4 } func (this *ACKSegment) Bytes(b []byte) []byte { b = serial.Uint16ToBytes(this.Conv, b) b = append(b, byte(SegmentCommandACK), byte(this.Opt)) b = serial.Uint32ToBytes(this.ReceivingWindow, b) - b = serial.Uint32ToBytes(this.Unacknowledged, b) + b = serial.Uint32ToBytes(this.ReceivingNext, b) b = append(b, this.Count) for i := byte(0); i < this.Count; i++ { b = serial.Uint32ToBytes(this.NumberList[i], b) @@ -89,25 +89,30 @@ func (this *ACKSegment) Bytes(b []byte) []byte { func (this *ACKSegment) Release() {} -type TerminationSegment struct { - Conv uint16 - Opt SegmentOption +type CmdOnlySegment struct { + Conv uint16 + Cmd SegmentCommand + Opt SegmentOption + SendingNext uint32 + ReceivinNext uint32 } -func (this *TerminationSegment) ByteSize() int { - return 2 + 1 + 1 +func (this *CmdOnlySegment) ByteSize() int { + return 2 + 1 + 1 + 4 + 4 } -func (this *TerminationSegment) Bytes(b []byte) []byte { +func (this *CmdOnlySegment) Bytes(b []byte) []byte { b = serial.Uint16ToBytes(this.Conv, b) - b = append(b, byte(SegmentCommandTerminated), byte(this.Opt)) + b = append(b, byte(this.Cmd), byte(this.Opt)) + b = serial.Uint32ToBytes(this.SendingNext, b) + b = serial.Uint32ToBytes(this.ReceivinNext, b) return b } -func (this *TerminationSegment) Release() {} +func (this *CmdOnlySegment) Release() {} func ReadSegment(buf []byte) (ISegment, []byte) { - if len(buf) <= 12 { + if len(buf) <= 6 { return nil, nil } @@ -123,16 +128,13 @@ func ReadSegment(buf []byte) (ISegment, []byte) { Conv: conv, Opt: opt, } - seg.ReceivingWindow = serial.BytesToUint32(buf) - buf = buf[4:] - seg.Timestamp = serial.BytesToUint32(buf) buf = buf[4:] seg.Number = serial.BytesToUint32(buf) buf = buf[4:] - seg.Unacknowledged = serial.BytesToUint32(buf) + seg.SendingNext = serial.BytesToUint32(buf) buf = buf[4:] len := serial.BytesToUint16(buf) @@ -152,7 +154,7 @@ func ReadSegment(buf []byte) (ISegment, []byte) { seg.ReceivingWindow = serial.BytesToUint32(buf) buf = buf[4:] - seg.Unacknowledged = serial.BytesToUint32(buf) + seg.ReceivingNext = serial.BytesToUint32(buf) buf = buf[4:] seg.Count = buf[0] @@ -170,12 +172,17 @@ func ReadSegment(buf []byte) (ISegment, []byte) { return seg, buf } - if cmd == SegmentCommandTerminated { - return &TerminationSegment{ - Conv: conv, - Opt: opt, - }, buf + seg := &CmdOnlySegment{ + Conv: conv, + Cmd: cmd, + Opt: opt, } - return nil, nil + seg.SendingNext = serial.BytesToUint32(buf) + buf = buf[4:] + + seg.ReceivinNext = serial.BytesToUint32(buf) + buf = buf[4:] + + return seg, buf } diff --git a/transport/internet/kcp/segment_test.go b/transport/internet/kcp/segment_test.go index c1fc1ac5..49f98880 100644 --- a/transport/internet/kcp/segment_test.go +++ b/transport/internet/kcp/segment_test.go @@ -20,12 +20,11 @@ func TestDataSegment(t *testing.T) { assert := assert.On(t) seg := &DataSegment{ - Conv: 1, - ReceivingWindow: 2, - Timestamp: 3, - Number: 4, - Unacknowledged: 5, - Data: alloc.NewSmallBuffer().Clear().Append([]byte{'a', 'b', 'c', 'd'}), + Conv: 1, + Timestamp: 3, + Number: 4, + SendingNext: 5, + Data: alloc.NewSmallBuffer().Clear().Append([]byte{'a', 'b', 'c', 'd'}), } nBytes := seg.ByteSize() @@ -36,9 +35,8 @@ func TestDataSegment(t *testing.T) { iseg, _ := ReadSegment(bytes) seg2 := iseg.(*DataSegment) assert.Uint16(seg2.Conv).Equals(seg.Conv) - assert.Uint32(seg2.ReceivingWindow).Equals(seg.ReceivingWindow) assert.Uint32(seg2.Timestamp).Equals(seg.Timestamp) - assert.Uint32(seg2.Unacknowledged).Equals(seg.Unacknowledged) + assert.Uint32(seg2.SendingNext).Equals(seg.SendingNext) assert.Uint32(seg2.Number).Equals(seg.Number) assert.Bytes(seg2.Data.Value).Equals(seg.Data.Value) } @@ -49,7 +47,7 @@ func TestACKSegment(t *testing.T) { seg := &ACKSegment{ Conv: 1, ReceivingWindow: 2, - Unacknowledged: 3, + ReceivingNext: 3, Count: 5, NumberList: []uint32{1, 3, 5, 7, 9}, TimestampList: []uint32{2, 4, 6, 8, 10}, @@ -64,7 +62,7 @@ func TestACKSegment(t *testing.T) { seg2 := iseg.(*ACKSegment) assert.Uint16(seg2.Conv).Equals(seg.Conv) assert.Uint32(seg2.ReceivingWindow).Equals(seg.ReceivingWindow) - assert.Uint32(seg2.Unacknowledged).Equals(seg.Unacknowledged) + assert.Uint32(seg2.ReceivingNext).Equals(seg.ReceivingNext) assert.Byte(seg2.Count).Equals(seg.Count) for i := byte(0); i < seg2.Count; i++ { assert.Uint32(seg2.TimestampList[i]).Equals(seg.TimestampList[i]) diff --git a/transport/internet/kcp/sending.go b/transport/internet/kcp/sending.go index e63c30e9..9f28ee26 100644 --- a/transport/internet/kcp/sending.go +++ b/transport/internet/kcp/sending.go @@ -4,14 +4,14 @@ type SendingQueue struct { start uint32 cap uint32 len uint32 - list []*Segment + list []*DataSegment } func NewSendingQueue(size uint32) *SendingQueue { return &SendingQueue{ start: 0, cap: size, - list: make([]*Segment, size), + list: make([]*DataSegment, size), len: 0, } } @@ -24,7 +24,7 @@ func (this *SendingQueue) IsEmpty() bool { return this.len == 0 } -func (this *SendingQueue) Pop() *Segment { +func (this *SendingQueue) Pop() *DataSegment { if this.IsEmpty() { return nil } @@ -38,7 +38,7 @@ func (this *SendingQueue) Pop() *Segment { return seg } -func (this *SendingQueue) Push(seg *Segment) { +func (this *SendingQueue) Push(seg *DataSegment) { if this.IsFull() { return } diff --git a/transport/internet/kcp/sending_test.go b/transport/internet/kcp/sending_test.go index 35224691..ca6486fc 100644 --- a/transport/internet/kcp/sending_test.go +++ b/transport/internet/kcp/sending_test.go @@ -12,10 +12,10 @@ func TestSendingQueue(t *testing.T) { queue := NewSendingQueue(3) - seg0 := &Segment{} - seg1 := &Segment{} - seg2 := &Segment{} - seg3 := &Segment{} + seg0 := &DataSegment{} + seg1 := &DataSegment{} + seg2 := &DataSegment{} + seg3 := &DataSegment{} assert.Bool(queue.IsEmpty()).IsTrue() assert.Bool(queue.IsFull()).IsFalse() @@ -44,10 +44,10 @@ func TestSendingQueueClear(t *testing.T) { queue := NewSendingQueue(3) - seg0 := &Segment{} - seg1 := &Segment{} - seg2 := &Segment{} - seg3 := &Segment{} + seg0 := &DataSegment{} + seg1 := &DataSegment{} + seg2 := &DataSegment{} + seg3 := &DataSegment{} queue.Push(seg0) assert.Bool(queue.IsEmpty()).IsFalse()