diff --git a/chunks/bstream.go b/chunks/bstream.go index 79f7f74ee..6c3d4fee7 100644 --- a/chunks/bstream.go +++ b/chunks/bstream.go @@ -4,11 +4,8 @@ import "io" // bstream is a stream of bits type bstream struct { - // the data stream - stream []byte - - count uint8 // how many bits are valid in current byte - shift uint8 // pos of next bit in current byte + stream []byte // the data stream + count uint8 // how many bits are valid in current byte } func newBReader(b []byte) *bstream { @@ -83,32 +80,22 @@ func (b *bstream) writeBits(u uint64, nbits int) { } } -func (b *bstream) headByte() byte { - return b.stream[0] << b.shift -} - -func (b *bstream) advance() { - b.stream = b.stream[1:] - b.shift = 0 -} - func (b *bstream) readBit() (bit, error) { if len(b.stream) == 0 { return false, io.EOF } if b.count == 0 { - b.advance() - // did we just run out of stuff to read? + b.stream = b.stream[1:] + if len(b.stream) == 0 { return false, io.EOF } b.count = 8 } - d := b.headByte() & 0x80 + d := (b.stream[0] << (8 - b.count)) & 0x80 b.count-- - b.shift++ return d != 0, nil } @@ -118,22 +105,21 @@ func (b *bstream) readByte() (byte, error) { } if b.count == 0 { - b.advance() + b.stream = b.stream[1:] if len(b.stream) == 0 { return 0, io.EOF } - - b.count = 8 + return b.stream[0], nil } if b.count == 8 { b.count = 0 - return b.headByte(), nil + return b.stream[0], nil } - byt := b.headByte() - b.advance() + byt := b.stream[0] << (8 - b.count) + b.stream = b.stream[1:] if len(b.stream) == 0 { return 0, io.EOF @@ -141,7 +127,6 @@ func (b *bstream) readByte() (byte, error) { // We just advanced the stream and can assume the shift to be 0. byt |= b.stream[0] >> b.count - b.shift = 8 - b.count return byt, nil } @@ -164,9 +149,9 @@ func (b *bstream) readBits(nbits int) (uint64, error) { } if nbits > int(b.count) { - u = (u << uint(b.count)) | uint64(b.headByte()>>(8-b.count)) + u = (u << uint(b.count)) | uint64((b.stream[0]<<(8-b.count))>>(8-b.count)) nbits -= int(b.count) - b.advance() + b.stream = b.stream[1:] if len(b.stream) == 0 { return 0, io.EOF @@ -174,8 +159,7 @@ func (b *bstream) readBits(nbits int) (uint64, error) { b.count = 8 } - u = (u << uint(nbits)) | uint64(b.headByte()>>(8-uint(nbits))) - b.shift = b.shift + uint8(nbits) + u = (u << uint(nbits)) | uint64((b.stream[0]<<(8-b.count))>>(8-uint(nbits))) b.count -= uint8(nbits) return u, nil }