diff --git a/transport/internet/kcp/segment.go b/transport/internet/kcp/segment.go index 5acf25be..9af1629d 100644 --- a/transport/internet/kcp/segment.go +++ b/transport/internet/kcp/segment.go @@ -31,6 +31,7 @@ type Segment interface { Command() Command ByteSize() int Bytes() buf.Supplier + parse(conv uint16, cmd Command, opt SegmentOption, buf []byte) (bool, []byte) } const ( @@ -53,6 +54,34 @@ func NewDataSegment() *DataSegment { return new(DataSegment) } +func (s *DataSegment) parse(conv uint16, cmd Command, opt SegmentOption, buf []byte) (bool, []byte) { + s.Conv = conv + s.Option = opt + if len(buf) < 15 { + return false, nil + } + s.Timestamp = serial.BytesToUint32(buf) + buf = buf[4:] + + s.Number = serial.BytesToUint32(buf) + buf = buf[4:] + + s.SendingNext = serial.BytesToUint32(buf) + buf = buf[4:] + + dataLen := int(serial.BytesToUint16(buf)) + buf = buf[2:] + + if len(buf) < dataLen { + return false, nil + } + s.Data().Clear() + s.Data().Append(buf[:dataLen]) + buf = buf[dataLen:] + + return true, buf +} + func (s *DataSegment) Conversation() uint16 { return s.Conv } @@ -113,6 +142,36 @@ func NewAckSegment() *AckSegment { } } +func (s *AckSegment) parse(conv uint16, cmd Command, opt SegmentOption, buf []byte) (bool, []byte) { + s.Conv = conv + s.Option = opt + if len(buf) < 13 { + return false, nil + } + + s.ReceivingWindow = serial.BytesToUint32(buf) + buf = buf[4:] + + s.ReceivingNext = serial.BytesToUint32(buf) + buf = buf[4:] + + s.Timestamp = serial.BytesToUint32(buf) + buf = buf[4:] + + count := int(buf[0]) + buf = buf[1:] + + if len(buf) < count*4 { + return false, nil + } + for i := 0; i < count; i++ { + s.PutNumber(serial.BytesToUint32(buf)) + buf = buf[4:] + } + + return true, buf +} + func (s *AckSegment) Conversation() uint16 { return s.Conv } @@ -176,6 +235,27 @@ func NewCmdOnlySegment() *CmdOnlySegment { return new(CmdOnlySegment) } +func (s *CmdOnlySegment) parse(conv uint16, cmd Command, opt SegmentOption, buf []byte) (bool, []byte) { + s.Conv = conv + s.Cmd = cmd + s.Option = opt + + if len(buf) < 12 { + return false, nil + } + + s.SendingNext = serial.BytesToUint32(buf) + buf = buf[4:] + + s.ReceivingNext = serial.BytesToUint32(buf) + buf = buf[4:] + + s.PeerRTO = serial.BytesToUint32(buf) + buf = buf[4:] + + return true, buf +} + func (s *CmdOnlySegment) Conversation() uint16 { return s.Conv } @@ -213,83 +293,19 @@ func ReadSegment(buf []byte) (Segment, []byte) { opt := SegmentOption(buf[1]) buf = buf[2:] - if cmd == CommandData { - seg := NewDataSegment() - seg.Conv = conv - seg.Option = opt - if len(buf) < 15 { - return nil, nil - } - seg.Timestamp = serial.BytesToUint32(buf) - buf = buf[4:] - - seg.Number = serial.BytesToUint32(buf) - buf = buf[4:] - - seg.SendingNext = serial.BytesToUint32(buf) - buf = buf[4:] - - dataLen := int(serial.BytesToUint16(buf)) - buf = buf[2:] - - if len(buf) < dataLen { - return nil, nil - } - seg.Data().Clear() - seg.Data().Append(buf[:dataLen]) - buf = buf[dataLen:] - - return seg, buf + var seg Segment + switch cmd { + case CommandData: + seg = NewDataSegment() + case CommandACK: + seg = NewAckSegment() + default: + seg = NewCmdOnlySegment() } - if cmd == CommandACK { - seg := NewAckSegment() - seg.Conv = conv - seg.Option = opt - if len(buf) < 13 { - return nil, nil - } - - seg.ReceivingWindow = serial.BytesToUint32(buf) - buf = buf[4:] - - seg.ReceivingNext = serial.BytesToUint32(buf) - buf = buf[4:] - - seg.Timestamp = serial.BytesToUint32(buf) - buf = buf[4:] - - count := int(buf[0]) - buf = buf[1:] - - if len(buf) < count*4 { - return nil, nil - } - for i := 0; i < count; i++ { - seg.PutNumber(serial.BytesToUint32(buf)) - buf = buf[4:] - } - - return seg, buf - } - - seg := NewCmdOnlySegment() - seg.Conv = conv - seg.Cmd = cmd - seg.Option = opt - - if len(buf) < 12 { + valid, extra := seg.parse(conv, cmd, opt, buf) + if !valid { return nil, nil } - - seg.SendingNext = serial.BytesToUint32(buf) - buf = buf[4:] - - seg.ReceivingNext = serial.BytesToUint32(buf) - buf = buf[4:] - - seg.PeerRTO = serial.BytesToUint32(buf) - buf = buf[4:] - - return seg, buf + return seg, extra }