@@ -1,8 +1,10 @@
 from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist


-def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -14,28 +16,28 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens

     Returns:
         Tuples: A tuple (fp8_tensor, scale)
     """
-    if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
-        return inp
+    if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+        raise TypeError("Only float16, bfloat16, and float32 are allowed.")

     fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
     fp8_max = torch.finfo(fp8_type).max

     if inp.dim() == 2:
-        if scale is None:
-            per_channel_max = inp.abs().max(dim=-1).values
-            scale = per_channel_max
-        scale_inv = 1.0 / scale
-        scale_inv = scale_inv[:, None]
-        ret = (scale_inv * inp).to(fp8_type)
+        per_channel_max = inp.abs().max(dim=-1).values.float()
+        per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
+        scale = fp8_max / per_channel_max[:, None]
     else:
-        if scale is None:
-            per_tensor_max = inp.abs().max()
-            scale = per_tensor_max
-        scale_inv = 1.0 / scale
-        ret = (scale_inv * inp).to(fp8_type)
+        per_tensor_max = inp.abs().max().float()
+        per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
+        scale = fp8_max / per_tensor_max

-    return ret, scale
+    scale_inv = 1.0 / scale
+    ret = (scale * inp.float()).to(fp8_type)
+    return ret, scale_inv
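
A minimal usage sketch of the new convention (illustrative only, not part of this diff; assumes a float8-capable PyTorch build, >= 2.1, and cast_to_fp8 as defined above): the scale is derived from the tensor's absolute maximum so the largest value lands on the fp8 range, and the inverse scale is what the caller gets back.

# Illustrative only -- not part of the change set.
import torch

x = torch.full((8,), 2.0, dtype=torch.float16)       # 1-D input -> per-tensor scaling path
fp8_x, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
print(fp8_x.dtype)       # torch.float8_e4m3fn
print(scale_inv.item())  # ~0.00446 == 2.0 / 448, since 448 is e4m3's largest finite value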


-def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
+def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
     r"""

     Args:
@@ -47,12 +49,13 @@ def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype)

         torch.Tensor
     """
     if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
-        return inp
+        raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")

     if inp.dim() == 2:
-        ret = scale[:, None] * inp.to(ret_type)
+        ret = scale_inv[:, None] * inp.float()
     else:
-        ret = scale * inp.to(ret_type)
-    return ret
+        ret = scale_inv * inp.float()
+    return ret.to(ret_type)
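
A round-trip sketch of how the two helpers pair up after this change (illustrative only, not part of this diff; assumes a float8-capable PyTorch build): cast_to_fp8 returns the inverse scale, and cast_from_fp8 consumes it together with the target dtype.

# Illustrative only -- per-tensor path (non-2-D input).
import torch

x = torch.randn(16, dtype=torch.bfloat16)
fp8_x, scale_inv = cast_to_fp8(x, fp8_format="e4m3")      # quantize, keep inverse scale
x_back = cast_from_fp8(fp8_x, scale_inv, torch.bfloat16)  # dequantize to the original dtype
assert x_back.dtype == torch.bfloat16
# e4m3 carries only 3 mantissa bits, so expect coarse agreement rather than equality.
print((x.float() - x_back.float()).abs().max())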


 def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
@@ -69,7 +72,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
     """

     world_size = dist.get_world_size()
-    rank = dist.get_rank()
     input_type = tensor.dtype
     input_shape = tensor.shape
     input_device = tensor.device