pull/5885/head
GuangyaoZhang 2024-07-10 08:13:26 +00:00
parent e17f835df7
commit dbfa7d39fc
1 changed files with 1 additions and 1 deletions

View File

@ -14,7 +14,7 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens
Returns:
Tuples: A tuple (fp8_tensor, scale)
"""
if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
return inp
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2