Browse Source

fix typo

pull/5885/head
GuangyaoZhang 5 months ago
parent
commit
dbfa7d39fc
  1. 2
      colossalai/quantization/fp8.py

2
colossalai/quantization/fp8.py

@ -14,7 +14,7 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens
Returns:
Tuples: A tuple (fp8_tensor, scale)
"""
if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
return inp
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2

Loading…
Cancel
Save