From dbfa7d39fc06534cf3d44ba8d1a5ae4d147d7133 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 10 Jul 2024 08:13:26 +0000 Subject: [PATCH] fix typo --- colossalai/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index c880cd4aa..58cedbc95 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -14,7 +14,7 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens Returns: Tuples: A tuple (fp8_tensor, scale) """ - if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]: + if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: return inp fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2