mirror of https://github.com/hpcaitech/ColossalAI
fix typo
parent
e17f835df7
commit
dbfa7d39fc
|
@ -14,7 +14,7 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens
|
|||
Returns:
|
||||
Tuples: A tuple (fp8_tensor, scale)
|
||||
"""
|
||||
if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
|
||||
if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
return inp
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
|
|
Loading…
Reference in New Issue