diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index d405de2de..051ecb45a 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -1,8 +1,10 @@
+from typing import Any, Callable, List, Optional, Tuple, Union
+
 import torch
 import torch.distributed as dist
 
 
-def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -14,28 +16,28 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens
     Returns:
         Tuples: A tuple (fp8_tensor, scale)
     """
-    if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
-        return inp
+
+    if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+        raise TypeError("Only float16, bfloat16, and float32 are allowed.")
+
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+    fp8_max = torch.finfo(fp8_type).max
 
     if inp.dim() == 2:
-        if scale is None:
-            per_channel_max = inp.abs().max(dim=-1).values
-            scale = per_channel_max
-        scale_inv = 1.0 / scale
-        scale_inv = scale_inv[:, None]
-        ret = (scale_inv * inp).to(fp8_type)
+        per_channel_max = inp.abs().max(dim=-1).values.float()
+        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+        scale = fp8_max / per_channel_max[:, None]
     else:
-        if scale is None:
-            per_tensor_max = inp.abs().max()
-            scale = per_tensor_max
-        scale_inv = 1.0 / scale
-        ret = (scale_inv * inp).to(fp8_type)
 
-    return ret, scale
+        per_tensor_max = inp.abs().max().float()
+        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+        scale = fp8_max / per_tensor_max
+
+    scale_inv = 1.0 / scale
+    ret = (scale * inp.float()).to(fp8_type)
+    return ret, scale_inv
 
 
-def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
+def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
     r"""
 
     Args:
@@ -47,12 +49,13 @@ def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype)
         torch.Tensor
     """
     if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
-        return inp
+        raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
+
     if inp.dim() == 2:
-        ret = scale[:, None] * inp.to(ret_type)
+        ret = scale_inv[:, None] * inp.float()
     else:
-        ret = scale * inp.to(ret_type)
-    return ret
+        ret = scale_inv * inp.float()
+    return ret.to(ret_type)
 
 
 def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
@@ -69,7 +72,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
     """
     world_size = dist.get_world_size()
-    rank = dist.get_rank()
     input_type = tensor.dtype
     input_shape = tensor.shape
     input_device = tensor.device
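For review context, here is a minimal round-trip sketch of the updated API. It assumes the patched `colossalai.quantization.fp8` module is importable and the PyTorch build provides the `float8` dtypes; the shapes, tolerances, and variable names are illustrative only and not part of the patch.

```python
import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

# 2-D input -> per-channel (per-row) scaling; scale is fp8_max / amax(row).
x = torch.randn(4, 8, dtype=torch.float16)
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")

# cast_to_fp8 now returns the *inverse* scale, which cast_from_fp8 consumes directly.
x_restored = cast_from_fp8(x_fp8, scale_inv, ret_type=torch.float16)
assert torch.allclose(x_restored, x, atol=1e-1, rtol=1e-1)  # fp8 cast is lossy

# 1-D input -> a single per-tensor scale; an all-zero input exercises the amax > 0 guard.
g = torch.zeros(16, dtype=torch.bfloat16)
g_fp8, g_scale_inv = cast_to_fp8(g, fp8_format="e5m2")
g_restored = cast_from_fp8(g_fp8, g_scale_inv, ret_type=torch.bfloat16)
assert torch.equal(g_restored, g)
```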