
fix scaling algorithm in FP8 casting

Branch: pull/5885/head
Author: BurkeHulk, 4 months ago
Commit: 1e1959467e
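
Before this commit, cast_to_fp8 used the tensor's absolute maximum directly as the scale, so quantized values were confined to [-1, 1] and most of the FP8 range went unused (e4m3 can represent magnitudes up to 448); an all-zero input also led to a division by zero. The new code scales by fp8_max / amax, guards the division with torch.where, and returns scale_inv for dequantization. Below is a minimal standalone sketch of the new per-tensor path, written for this page rather than taken from the repository; the variable names are illustrative.

import torch

# Illustration of the new per-tensor scaling; the tensor and names are made up.
x = torch.randn(4, 8)
fp8_type = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_type).max                        # 448.0 for e4m3

amax = x.abs().max().float()
amax = torch.where(amax > 0, amax, torch.ones_like(amax))  # avoid dividing by zero
scale = fp8_max / amax                                     # map amax onto the largest FP8 value
x_fp8 = (scale * x.float()).to(fp8_type)                   # quantize
scale_inv = 1.0 / scale                                    # what cast_to_fp8 now returns
x_back = scale_inv * x_fp8.float()                         # dequantize, as cast_from_fp8 does

The 2-D branch in the diff applies the same formula row by row via inp.abs().max(dim=-1).
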
1 changed file: colossalai/quantization/fp8.py (44 lines changed)

@@ -1,8 +1,10 @@
 from typing import Any, Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
-def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -14,28 +16,28 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens
     Returns:
         Tuples: A tuple (fp8_tensor, scale)
     """
-    if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
-        return inp
+    if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+        raise TypeError("Only float16, bfloat16, and float32 are allowed.")
     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
+    fp8_max = torch.finfo(fp8_type).max
     if inp.dim() == 2:
-        if scale is None:
-            per_channel_max = inp.abs().max(dim=-1).values
-            scale = per_channel_max
-        scale_inv = 1.0 / scale
-        scale_inv = scale_inv[:, None]
-        ret = (scale_inv * inp).to(fp8_type)
+        per_channel_max = inp.abs().max(dim=-1).values.float()
+        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+        scale = fp8_max / per_channel_max[:, None]
     else:
-        if scale is None:
-            per_tensor_max = inp.abs().max()
-            scale = per_tensor_max
-        scale_inv = 1.0 / scale
-        ret = (scale_inv * inp).to(fp8_type)
+        per_tensor_max = inp.abs().max().float()
+        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+        scale = fp8_max / per_tensor_max
-    return ret, scale
+    scale_inv = 1.0 / scale
+    ret = (scale * inp.float()).to(fp8_type)
+    return ret, scale_inv

-def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
+def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
     r"""
     Args:
@@ -47,12 +49,13 @@ def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype)
         torch.Tensor
     """
     if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
-        return inp
+        raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")
     if inp.dim() == 2:
-        ret = scale[:, None] * inp.to(ret_type)
+        ret = scale_inv[:, None] * inp.float()
     else:
-        ret = scale * inp.to(ret_type)
-    return ret
+        ret = scale_inv * inp.float()
+    return ret.to(ret_type)

 def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
@@ -69,7 +72,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
     """
     world_size = dist.get_world_size()
     rank = dist.get_rank()
-    input_type = tensor.dtype
     input_shape = tensor.shape
     input_device = tensor.device
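
A short round-trip sketch of the updated pair of functions, assuming the module imports from the file path shown above; the input shape, dtype, and the final print are illustrative choices, not part of the commit. A 1-D tensor is used so the per-tensor scaling branch is exercised.

import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

x = torch.randn(1024, dtype=torch.bfloat16)

# cast_to_fp8 now returns (fp8_tensor, scale_inv) instead of (fp8_tensor, scale);
# non-2-D inputs like this one take the per-tensor scaling branch.
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
x_back = cast_from_fp8(x_fp8, scale_inv, ret_type=torch.bfloat16)

assert x_fp8.dtype == torch.float8_e4m3fn
assert x_back.dtype == torch.bfloat16 and x_back.shape == x.shape
# FP8 keeps only a few mantissa bits, so expect a coarse match rather than exact equality.
print((x_back.float() - x.float()).abs().max())
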
