From f5a52e1600e59d461a11abbf05585c0e0b19c906 Mon Sep 17 00:00:00 2001 From: HangXu Date: Mon, 1 Jul 2024 13:44:21 +0800 Subject: [PATCH] fp8 operators for compressed communication cast_to_fp8, cast_from_fp8, all_reduce_fp8 --- colossalai/quantization/fp8.py | 105 +++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 colossalai/quantization/fp8.py diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py new file mode 100644 index 000000000..d405de2de --- /dev/null +++ b/colossalai/quantization/fp8.py @@ -0,0 +1,105 @@ +import torch +import torch.distributed as dist + + +def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor): + r""" + casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. + Args: + inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor. + scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling + is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. + fp8_format: e4m3 or e5m2 + + Returns: + Tuples: A tuple (fp8_tensor, scale) + """ + if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]: + return inp + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + if inp.dim() == 2: + if scale is None: + per_channel_max = inp.abs().max(dim=-1).values + scale = per_channel_max + scale_inv = 1.0 / scale + scale_inv = scale_inv[:, None] + ret = (scale_inv * inp).to(fp8_type) + else: + if scale is None: + per_tensor_max = inp.abs().max() + scale = per_tensor_max + scale_inv = 1.0 / scale + ret = (scale_inv * inp).to(fp8_type) + + return ret, scale + + +def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor: + r""" + + Args: + inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2]. + scale: scaling factor returned by cast_to_fp8 function. + ret_type: the datatype of the returned tensor. + + Returns: + torch.Tensor + """ + if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + return inp + if inp.dim() == 2: + ret = scale[:, None] * inp.to(ret_type) + else: + ret = scale * inp.to(ret_type) + return ret + + +def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: + r""" + This is an in-place operation for compressed all_reduce using fp8. + It works like dist.all_reduce but during communication the data is cast to fp8 format. + + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. + fp8_format: e4m3 or e5m2 + + Returns: + None + """ + + world_size = dist.get_world_size() + rank = dist.get_rank() + input_type = tensor.dtype + input_shape = tensor.shape + input_device = tensor.device + input_size = tensor.numel() + tensor = tensor.flatten() + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format) + + inp = ret.view(torch.uint8) + input_chunks = list(torch.chunk(inp, world_size, dim=0)) + if dist.get_rank() == world_size - 1: + output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] + else: + output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] + dist.all_to_all(output_chunks, input_chunks) + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + dist.all_gather(scale_list, scale) + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + for scale, out in zip(scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + + summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) + dist.all_gather(scale_list, scale) + + tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0)) + dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8)) + for i in range(world_size): + tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] + tensor_out = torch.cat(tensor_list, dim=0) + tensor.data = tensor_out.view(input_shape).to(input_type) \ No newline at end of file