Browse Source

fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8
pull/5885/head
HangXu 5 months ago committed by GitHub
parent
commit
f5a52e1600
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 105
      colossalai/quantization/fp8.py

105
colossalai/quantization/fp8.py

@ -0,0 +1,105 @@
import torch
import torch.distributed as dist
def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
r"""
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
Args:
inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor.
scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling
is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied.
fp8_format: e4m3 or e5m2
Returns:
Tuples: A tuple (fp8_tensor, scale)
"""
if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]:
return inp
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
if inp.dim() == 2:
if scale is None:
per_channel_max = inp.abs().max(dim=-1).values
scale = per_channel_max
scale_inv = 1.0 / scale
scale_inv = scale_inv[:, None]
ret = (scale_inv * inp).to(fp8_type)
else:
if scale is None:
per_tensor_max = inp.abs().max()
scale = per_tensor_max
scale_inv = 1.0 / scale
ret = (scale_inv * inp).to(fp8_type)
return ret, scale
def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
r"""
Args:
inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
scale: scaling factor returned by cast_to_fp8 function.
ret_type: the datatype of the returned tensor.
Returns:
torch.Tensor
"""
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
return inp
if inp.dim() == 2:
ret = scale[:, None] * inp.to(ret_type)
else:
ret = scale * inp.to(ret_type)
return ret
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.
Args:
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
fp8_format: e4m3 or e5m2
Returns:
None
"""
world_size = dist.get_world_size()
rank = dist.get_rank()
input_type = tensor.dtype
input_shape = tensor.shape
input_device = tensor.device
input_size = tensor.numel()
tensor = tensor.flatten()
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format)
inp = ret.view(torch.uint8)
input_chunks = list(torch.chunk(inp, world_size, dim=0))
if dist.get_rank() == world_size - 1:
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
else:
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
dist.all_to_all(output_chunks, input_chunks)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale)
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
dist.all_gather(scale_list, scale)
tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8))
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
tensor_out = torch.cat(tensor_list, dim=0)
tensor.data = tensor_out.view(input_shape).to(input_type)
Loading…
Cancel
Save