You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_fp8/test_fp8_cast.py

27 lines
1.0 KiB

import torch
from torch.testing import assert_close
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize
@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def test_fp8_cast(shape, dtype, fp8_format):
    """Check that an FP8 round trip reproduces the input within tolerance.

    A random tensor of the given ``shape`` and ``dtype`` is created on the
    current accelerator device, cast to FP8 with ``cast_to_fp8`` and cast back
    with ``cast_from_fp8``. The result must match the input with
    ``rtol=0.1`` and ``atol=0.1``.

    When the last dimension has an even size, the same check is repeated
    through ``cast_to_fp8_pipeline`` / ``cast_from_fp8_pipeline`` using a dict
    that holds a clone of the tensor under the ``"hidden_states"`` key.

    Args:
        shape: Shape of the random input tensor.
        dtype: Torch dtype of the input tensor and of the restored output.
        fp8_format: FP8 format name forwarded to ``cast_to_fp8``.
    """
    device = get_accelerator().get_current_device()
    reference = torch.rand(shape, dtype=dtype, device=device)

    # Direct path: quantize, then restore to the input dtype.
    quantized, scale_inv = cast_to_fp8(reference, fp8_format=fp8_format)
    restored = cast_from_fp8(quantized, scale_inv, reference.dtype)
    assert_close(restored, reference, rtol=0.1, atol=0.1)

    # The pipeline path is only exercised for an even trailing dimension.
    # NOTE(review): presumably a requirement of the pipeline cast -- confirm.
    if reference.size(-1) % 2 != 0:
        return

    # The return values of the pipeline helpers are ignored; the comparison
    # below reads the tensor back out of the same dict.
    payload = {"hidden_states": reference.clone()}
    cast_to_fp8_pipeline(payload)
    cast_from_fp8_pipeline(payload)
    assert_close(payload["hidden_states"], reference, rtol=0.1, atol=0.1)
# Allow running this file directly as a script. The test is called without
# arguments; the @parameterize decorators on it are expected to supply
# shape, dtype and fp8_format.
if __name__ == "__main__":
    test_fp8_cast()