mirror of https://github.com/hpcaitech/ColossalAI
fix scaling algorithm in FP8 casting
parent f5a52e1600
commit 1e1959467e
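Before this commit, cast_to_fp8 used the absolute maximum itself as the scale and encoded inp / amax, so every stored value fell in [-1, 1] and the FP8 range above 1 went unused; the dtype guard also compared inp.dtype against the legacy tensor classes (torch.FloatTensor and friends), so it never matched. The diff below checks the dtypes explicitly, computes scale = fp8_max / amax (448 for e4m3, 57344 for e5m2) so the largest magnitude lands on the format maximum, guards against an all-zero input with torch.where, does the scaling in float32, and returns scale_inv for dequantization. A minimal, self-contained sketch of the corrected per-tensor path, using only public torch APIs (the helper name cast_to_fp8_sketch is illustrative, not part of ColossalAI):

import torch

def cast_to_fp8_sketch(inp: torch.Tensor, fp8_format: str = "e4m3"):
    # Per-tensor variant of the scaling scheme shown in the diff (sketch only).
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max                         # 448.0 for e4m3, 57344.0 for e5m2
    amax = inp.abs().max().float()
    amax = torch.where(amax > 0, amax, torch.ones_like(amax))   # all-zero input -> scale 1
    scale = fp8_max / amax                                      # largest magnitude maps to fp8_max
    scale_inv = 1.0 / scale
    return (scale * inp.float()).to(fp8_type), scale_inv

# Round trip (needs a PyTorch build with float8 dtypes, e.g. torch >= 2.1):
x = torch.randn(1024, dtype=torch.bfloat16)
x_fp8, scale_inv = cast_to_fp8_sketch(x)
x_back = (scale_inv * x_fp8.float()).to(torch.bfloat16)
print((x.float() - x_back.float()).abs().max())                 # small quantization error, not zero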
@@ -1,8 +1,10 @@
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 
 
-def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -14,28 +16,28 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens
     Returns:
         Tuples: A tuple (fp8_tensor, scale)
     """
-    if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
-        return inp
+    if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+        raise TypeError("Only float16, bfloat16, and float32 are allowed.")
 
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+    fp8_max = torch.finfo(fp8_type).max
 
     if inp.dim() == 2:
-        if scale is None:
-            per_channel_max = inp.abs().max(dim=-1).values
-            scale = per_channel_max
-        scale_inv = 1.0 / scale
-        scale_inv = scale_inv[:, None]
-        ret = (scale_inv * inp).to(fp8_type)
+        per_channel_max = inp.abs().max(dim=-1).values.float()
+        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+        scale = fp8_max / per_channel_max[:, None]
     else:
-        if scale is None:
-            per_tensor_max = inp.abs().max()
-            scale = per_tensor_max
-        scale_inv = 1.0 / scale
-        ret = (scale_inv * inp).to(fp8_type)
-    return ret, scale
+        per_tensor_max = inp.abs().max().float()
+        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+        scale = fp8_max / per_tensor_max
+
+    scale_inv = 1.0 / scale
+    ret = (scale * inp.float()).to(fp8_type)
+    return ret, scale_inv
 
 
-def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
+def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
     r"""
 
     Args:
@@ -47,12 +49,13 @@ def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype)
         torch.Tensor
     """
     if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
-        return inp
+        raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
+
     if inp.dim() == 2:
-        ret = scale[:, None] * inp.to(ret_type)
+        ret = scale_inv[:, None] * inp.float()
     else:
-        ret = scale * inp.to(ret_type)
-    return ret
+        ret = scale_inv * inp.float()
+    return ret.to(ret_type)
 
 
 def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
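For 2-D inputs the cast stays per-channel: the amax is taken over the last dimension, so there is one scale per row and it has to be broadcast back over the columns when dequantizing (hence the [:, None] in the diff). A standalone sketch of that broadcasting with plain tensor ops, not the library functions; the shapes here are a sketch-level choice and may not match ColossalAI's internal convention exactly:

import torch

rows, cols = 4, 8
x = torch.randn(rows, cols)
fp8_max = torch.finfo(torch.float8_e4m3fn).max

row_max = x.abs().max(dim=-1).values.float()                      # shape (rows,)
row_max = torch.where(row_max > 0, row_max, torch.ones_like(row_max))
scale = fp8_max / row_max[:, None]                                # shape (rows, 1), broadcasts over columns
x_fp8 = (scale * x).to(torch.float8_e4m3fn)

scale_inv = 1.0 / scale                                           # also (rows, 1)
x_back = scale_inv * x_fp8.float()                                # broadcast multiply undoes the per-row scale
print(x_back.shape, (x - x_back).abs().max())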
@@ -69,7 +72,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
     """
 
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     input_type = tensor.dtype
     input_shape = tensor.shape
     input_device = tensor.device
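The last hunk only reaches the preamble of all_reduce_fp8 (world size, rank, and the input's dtype, shape, and device); the rest of the body is outside this diff. Purely for orientation, and assuming nothing about ColossalAI's actual implementation, one naive way to assemble a compressed all-reduce from the casts above is: quantize locally, all-gather the FP8 payloads plus their scale_inv values, then dequantize and sum in higher precision. A sketch, with the helper name all_reduce_fp8_sketch being hypothetical:

import torch
import torch.distributed as dist

def all_reduce_fp8_sketch(tensor: torch.Tensor, fp8_format: str = "e4m3") -> None:
    # Naive illustration only: all-gathers the full payload instead of doing a
    # bandwidth-optimal reduce-scatter; NOT ColossalAI's implementation.
    world_size = dist.get_world_size()
    input_type = tensor.dtype

    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    amax = tensor.abs().max().float()
    amax = torch.where(amax > 0, amax, torch.ones_like(amax))
    scale = fp8_max / amax
    local_fp8 = (scale * tensor.float()).to(fp8_type)
    scale_inv = (1.0 / scale).reshape(1)

    # Communicate the FP8 payload as raw bytes so the backend only sees uint8.
    payloads = [torch.empty_like(local_fp8.view(torch.uint8)) for _ in range(world_size)]
    dist.all_gather(payloads, local_fp8.view(torch.uint8))
    scales = [torch.empty_like(scale_inv) for _ in range(world_size)]
    dist.all_gather(scales, scale_inv)

    # Dequantize every rank's contribution, sum in float32, then write back in place.
    total = torch.zeros_like(tensor, dtype=torch.float32)
    for payload, s_inv in zip(payloads, scales):
        total += s_inv * payload.view(fp8_type).float()
    tensor.copy_(total.to(input_type))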