diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index c880cd4aa..58cedbc95 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -14,7 +14,7 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens Returns: Tuples: A tuple (fp8_tensor, scale) """ - if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]: + if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: return inp fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2