diff --git a/storage/local/chunk/delta.go b/storage/local/chunk/delta.go index 2be00cce9..4e3fd0645 100644 --- a/storage/local/chunk/delta.go +++ b/storage/local/chunk/delta.go @@ -238,27 +238,36 @@ func (c *deltaEncodedChunk) Unmarshal(r io.Reader) error { if _, err := io.ReadFull(r, *c); err != nil { return err } - l := binary.LittleEndian.Uint16((*c)[deltaHeaderBufLenOffset:]) - if int(l) > cap(*c) { - return fmt.Errorf("chunk length exceeded during unmarshaling: %d", l) - } - if int(l) < deltaHeaderBytes { - return fmt.Errorf("chunk length less than header size: %d < %d", l, deltaHeaderBytes) - } - *c = (*c)[:l] - return nil + return c.setLen() } // UnmarshalFromBuf implements chunk. func (c *deltaEncodedChunk) UnmarshalFromBuf(buf []byte) error { *c = (*c)[:cap(*c)] copy(*c, buf) + return c.setLen() +} + +// setLen sets the length of the underlying slice and performs some sanity checks. +func (c *deltaEncodedChunk) setLen() error { l := binary.LittleEndian.Uint16((*c)[deltaHeaderBufLenOffset:]) if int(l) > cap(*c) { - return fmt.Errorf("chunk length exceeded during unmarshaling: %d", l) + return fmt.Errorf("delta chunk length exceeded during unmarshaling: %d", l) } if int(l) < deltaHeaderBytes { - return fmt.Errorf("chunk length less than header size: %d < %d", l, deltaHeaderBytes) + return fmt.Errorf("delta chunk length less than header size: %d < %d", l, deltaHeaderBytes) + } + switch c.timeBytes() { + case d1, d2, d4, d8: + // Pass. + default: + return fmt.Errorf("invalid number of time bytes in delta chunk: %d", c.timeBytes()) + } + switch c.valueBytes() { + case d0, d1, d2, d4, d8: + // Pass. + default: + return fmt.Errorf("invalid number of value bytes in delta chunk: %d", c.valueBytes()) } *c = (*c)[:l] return nil diff --git a/storage/local/chunk/delta_test.go b/storage/local/chunk/delta_test.go index 357929574..09fd35c21 100644 --- a/storage/local/chunk/delta_test.go +++ b/storage/local/chunk/delta_test.go @@ -32,21 +32,20 @@ func TestUnmarshalingCorruptedDeltaReturnsAnError(t *testing.T) { err error, chunkTypeName string, unmarshalMethod string, - badLen int) { + expectedStr string, + ) { if err == nil { - t.Errorf("Failed to obtain an error when unmarshalling %s (from %s) with corrupt length of %d", chunkTypeName, unmarshalMethod, badLen) + t.Errorf("Failed to obtain an error when unmarshalling corrupt %s (from %s)", chunkTypeName, unmarshalMethod) return } - expectedStr := "header size" if !strings.Contains(err.Error(), expectedStr) { t.Errorf( - "'%s' not present in error when unmarshalling %s (from %s) with corrupt length %d: '%s'", + "'%s' not present in error when unmarshalling corrupt %s (from %s): '%s'", expectedStr, chunkTypeName, unmarshalMethod, - badLen, err.Error()) } } @@ -56,6 +55,7 @@ func TestUnmarshalingCorruptedDeltaReturnsAnError(t *testing.T) { chunkConstructor func(deltaBytes, deltaBytes, bool, int) Chunk minHeaderLen int chunkLenPos int + timeBytesPos int }{ { chunkTypeName: "deltaEncodedChunk", @@ -64,6 +64,7 @@ func TestUnmarshalingCorruptedDeltaReturnsAnError(t *testing.T) { }, minHeaderLen: deltaHeaderBytes, chunkLenPos: deltaHeaderBufLenOffset, + timeBytesPos: deltaHeaderTimeBytesOffset, }, { chunkTypeName: "doubleDeltaEncodedChunk", @@ -72,6 +73,7 @@ func TestUnmarshalingCorruptedDeltaReturnsAnError(t *testing.T) { }, minHeaderLen: doubleDeltaHeaderMinBytes, chunkLenPos: doubleDeltaHeaderBufLenOffset, + timeBytesPos: doubleDeltaHeaderTimeBytesOffset, }, } for _, c := range cases { @@ -89,15 +91,26 @@ func TestUnmarshalingCorruptedDeltaReturnsAnError(t *testing.T) { cs[0].MarshalToBuf(buf) + // Corrupt time byte to 0, which is illegal. + buf[c.timeBytesPos] = 0 + err = cs[0].UnmarshalFromBuf(buf) + verifyUnmarshallingError(err, c.chunkTypeName, "buf", "invalid number of time bytes") + + err = cs[0].Unmarshal(bytes.NewBuffer(buf)) + verifyUnmarshallingError(err, c.chunkTypeName, "Reader", "invalid number of time bytes") + + // Fix the corruption to go on. + buf[c.timeBytesPos] = byte(d1) + // Corrupt the length to be every possible too-small value for i := 0; i < c.minHeaderLen; i++ { binary.LittleEndian.PutUint16(buf[c.chunkLenPos:], uint16(i)) err = cs[0].UnmarshalFromBuf(buf) - verifyUnmarshallingError(err, c.chunkTypeName, "buf", i) + verifyUnmarshallingError(err, c.chunkTypeName, "buf", "header size") err = cs[0].Unmarshal(bytes.NewBuffer(buf)) - verifyUnmarshallingError(err, c.chunkTypeName, "Reader", i) + verifyUnmarshallingError(err, c.chunkTypeName, "Reader", "header size") } } } diff --git a/storage/local/chunk/doubledelta.go b/storage/local/chunk/doubledelta.go index 2a1221461..249c99d54 100644 --- a/storage/local/chunk/doubledelta.go +++ b/storage/local/chunk/doubledelta.go @@ -247,28 +247,36 @@ func (c *doubleDeltaEncodedChunk) Unmarshal(r io.Reader) error { if _, err := io.ReadFull(r, *c); err != nil { return err } - l := binary.LittleEndian.Uint16((*c)[doubleDeltaHeaderBufLenOffset:]) - if int(l) > cap(*c) { - return fmt.Errorf("chunk length exceeded during unmarshaling: %d", l) - } - if int(l) < doubleDeltaHeaderMinBytes { - return fmt.Errorf("chunk length less than header size: %d < %d", l, doubleDeltaHeaderMinBytes) - } - - *c = (*c)[:l] - return nil + return c.setLen() } // UnmarshalFromBuf implements chunk. func (c *doubleDeltaEncodedChunk) UnmarshalFromBuf(buf []byte) error { *c = (*c)[:cap(*c)] copy(*c, buf) + return c.setLen() +} + +// setLen sets the length of the underlying slice and performs some sanity checks. +func (c *doubleDeltaEncodedChunk) setLen() error { l := binary.LittleEndian.Uint16((*c)[doubleDeltaHeaderBufLenOffset:]) if int(l) > cap(*c) { - return fmt.Errorf("chunk length exceeded during unmarshaling: %d", l) + return fmt.Errorf("doubledelta chunk length exceeded during unmarshaling: %d", l) } if int(l) < doubleDeltaHeaderMinBytes { - return fmt.Errorf("chunk length less than header size: %d < %d", l, doubleDeltaHeaderMinBytes) + return fmt.Errorf("doubledelta chunk length less than header size: %d < %d", l, doubleDeltaHeaderMinBytes) + } + switch c.timeBytes() { + case d1, d2, d4, d8: + // Pass. + default: + return fmt.Errorf("invalid number of time bytes in doubledelta chunk: %d", c.timeBytes()) + } + switch c.valueBytes() { + case d0, d1, d2, d4, d8: + // Pass. + default: + return fmt.Errorf("invalid number of value bytes in doubledelta chunk: %d", c.valueBytes()) } *c = (*c)[:l] return nil