mirror of https://github.com/hpcaitech/ColossalAI
HangXu
5 months ago
committed by
GitHub
1 changed files with 105 additions and 0 deletions
@ -0,0 +1,105 @@ |
|||||||
|
import torch |
||||||
|
import torch.distributed as dist |
||||||
|
|
||||||
|
|
||||||
|
def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor): |
||||||
|
r""" |
||||||
|
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. |
||||||
|
Args: |
||||||
|
inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor. |
||||||
|
scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling |
||||||
|
is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. |
||||||
|
fp8_format: e4m3 or e5m2 |
||||||
|
|
||||||
|
Returns: |
||||||
|
Tuples: A tuple (fp8_tensor, scale) |
||||||
|
""" |
||||||
|
if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]: |
||||||
|
return inp |
||||||
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 |
||||||
|
|
||||||
|
if inp.dim() == 2: |
||||||
|
if scale is None: |
||||||
|
per_channel_max = inp.abs().max(dim=-1).values |
||||||
|
scale = per_channel_max |
||||||
|
scale_inv = 1.0 / scale |
||||||
|
scale_inv = scale_inv[:, None] |
||||||
|
ret = (scale_inv * inp).to(fp8_type) |
||||||
|
else: |
||||||
|
if scale is None: |
||||||
|
per_tensor_max = inp.abs().max() |
||||||
|
scale = per_tensor_max |
||||||
|
scale_inv = 1.0 / scale |
||||||
|
ret = (scale_inv * inp).to(fp8_type) |
||||||
|
|
||||||
|
return ret, scale |
||||||
|
|
||||||
|
|
||||||
|
def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor: |
||||||
|
r""" |
||||||
|
|
||||||
|
Args: |
||||||
|
inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2]. |
||||||
|
scale: scaling factor returned by cast_to_fp8 function. |
||||||
|
ret_type: the datatype of the returned tensor. |
||||||
|
|
||||||
|
Returns: |
||||||
|
torch.Tensor |
||||||
|
""" |
||||||
|
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: |
||||||
|
return inp |
||||||
|
if inp.dim() == 2: |
||||||
|
ret = scale[:, None] * inp.to(ret_type) |
||||||
|
else: |
||||||
|
ret = scale * inp.to(ret_type) |
||||||
|
return ret |
||||||
|
|
||||||
|
|
||||||
|
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: |
||||||
|
r""" |
||||||
|
This is an in-place operation for compressed all_reduce using fp8. |
||||||
|
It works like dist.all_reduce but during communication the data is cast to fp8 format. |
||||||
|
|
||||||
|
Args: |
||||||
|
tensor: torch.Tensor in fp32, fp16, bf16 datatype. |
||||||
|
fp8_format: e4m3 or e5m2 |
||||||
|
|
||||||
|
Returns: |
||||||
|
None |
||||||
|
""" |
||||||
|
|
||||||
|
world_size = dist.get_world_size() |
||||||
|
rank = dist.get_rank() |
||||||
|
input_type = tensor.dtype |
||||||
|
input_shape = tensor.shape |
||||||
|
input_device = tensor.device |
||||||
|
input_size = tensor.numel() |
||||||
|
tensor = tensor.flatten() |
||||||
|
|
||||||
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 |
||||||
|
|
||||||
|
ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format) |
||||||
|
|
||||||
|
inp = ret.view(torch.uint8) |
||||||
|
input_chunks = list(torch.chunk(inp, world_size, dim=0)) |
||||||
|
if dist.get_rank() == world_size - 1: |
||||||
|
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] |
||||||
|
else: |
||||||
|
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] |
||||||
|
dist.all_to_all(output_chunks, input_chunks) |
||||||
|
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] |
||||||
|
dist.all_gather(scale_list, scale) |
||||||
|
summed_out = torch.zeros_like(output_chunks[0]).to(input_type) |
||||||
|
for scale, out in zip(scale_list, output_chunks): |
||||||
|
out = out.view(fp8_type) |
||||||
|
summed_out += cast_from_fp8(out, scale, input_type) |
||||||
|
|
||||||
|
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) |
||||||
|
dist.all_gather(scale_list, scale) |
||||||
|
|
||||||
|
tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0)) |
||||||
|
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8)) |
||||||
|
for i in range(world_size): |
||||||
|
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] |
||||||
|
tensor_out = torch.cat(tensor_list, dim=0) |
||||||
|
tensor.data = tensor_out.view(input_shape).to(input_type) |
Loading…
Reference in new issue