mirror of https://github.com/hpcaitech/ColossalAI
[fp8] support all-gather flat tensor (#5932)
parent
62661cde22
commit
5fd0592767
|
@ -1,5 +1,6 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
@ -202,3 +203,78 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
|
||||||
out = out.view(fp8_type)
|
out = out.view(fp8_type)
|
||||||
summed_out += cast_from_fp8(out, scale, input_type)
|
summed_out += cast_from_fp8(out, scale, input_type)
|
||||||
output.data = summed_out
|
output.data = summed_out
|
||||||
|
|
||||||
|
|
||||||
|
def split_chunk_by_channel(
|
||||||
|
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
|
||||||
|
):
|
||||||
|
offset = chunk.numel() * rank
|
||||||
|
end = offset + chunk.numel()
|
||||||
|
break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end]
|
||||||
|
if len(break_points) == 0 or break_points[0] > offset:
|
||||||
|
break_points.insert(0, offset)
|
||||||
|
if break_points[-1] < end:
|
||||||
|
break_points.append(end)
|
||||||
|
sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])]
|
||||||
|
return chunk.split(sizes)
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_into_tensor_flat_fp8(
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
output_shape: torch.Size,
|
||||||
|
group: dist.ProcessGroup,
|
||||||
|
fp8_format: str = "e4m3",
|
||||||
|
):
|
||||||
|
"""all gather into tensor in fp8 format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_tensor (torch.Tensor): output tensor, which is flattened
|
||||||
|
input_tensor (torch.Tensor): input tensor, which is flattened
|
||||||
|
group (dist.ProcessGroup): process group
|
||||||
|
fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3".
|
||||||
|
"""
|
||||||
|
assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened"
|
||||||
|
world_size = dist.get_world_size(group)
|
||||||
|
assert (
|
||||||
|
output_tensor.numel() == input_tensor.numel() * world_size
|
||||||
|
), "output tensor size should be world_size times of input tensor size"
|
||||||
|
|
||||||
|
input_type = output_tensor.dtype
|
||||||
|
|
||||||
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||||
|
fp8_max = torch.finfo(fp8_type).max
|
||||||
|
|
||||||
|
if len(output_shape) == 2:
|
||||||
|
per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float)
|
||||||
|
num_channels, channel_size = output_shape
|
||||||
|
rank = dist.get_rank(group)
|
||||||
|
channel_start_idx = (input_tensor.numel() * rank) // channel_size
|
||||||
|
per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size)
|
||||||
|
for i, per_channel_split in enumerate(per_channel_splits):
|
||||||
|
idx = i + channel_start_idx
|
||||||
|
if idx < num_channels:
|
||||||
|
per_channel_max[idx] = per_channel_split.abs().max().float()
|
||||||
|
dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group)
|
||||||
|
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
|
||||||
|
scale = fp8_max / per_channel_max
|
||||||
|
fp8_input = input_tensor.float()
|
||||||
|
fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size)
|
||||||
|
for i, per_channel_split in enumerate(fp8_per_channel_splits):
|
||||||
|
idx = i + channel_start_idx
|
||||||
|
if idx < num_channels:
|
||||||
|
per_channel_split.mul_(scale[idx])
|
||||||
|
fp8_input = fp8_input.to(fp8_type)
|
||||||
|
else:
|
||||||
|
per_tensor_max = input_tensor.abs().max().float()
|
||||||
|
dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group)
|
||||||
|
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
||||||
|
scale = fp8_max / per_tensor_max
|
||||||
|
fp8_input = (scale * input_tensor.float()).to(fp8_type)
|
||||||
|
scale_inv = 1.0 / scale
|
||||||
|
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
|
||||||
|
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
|
||||||
|
numel = np.prod(output_shape)
|
||||||
|
valid_buffer = buffer[:numel].reshape(output_shape)
|
||||||
|
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type)
|
||||||
|
output_tensor[:numel].copy_(valid_buffer.view(-1))
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.distributed.distributed_c10d import _get_default_group
|
||||||
|
from torch.testing import assert_close
|
||||||
|
|
||||||
|
from colossalai import launch
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
|
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
|
||||||
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
|
||||||
|
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||||
|
def check_4gpu(shape, dtype):
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
rank = dist.get_rank()
|
||||||
|
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||||
|
flat_padded_x = x.view(-1)
|
||||||
|
if flat_padded_x.size(0) % world_size != 0:
|
||||||
|
pad_size = world_size - flat_padded_x.size(0) % world_size
|
||||||
|
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
||||||
|
output = torch.empty_like(flat_padded_x)
|
||||||
|
chunk = flat_padded_x.chunk(world_size)[rank].clone()
|
||||||
|
all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group())
|
||||||
|
assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||||
|
check_4gpu()
|
||||||
|
|
||||||
|
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_all_gather():
|
||||||
|
spawn(run_dist, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_all_gather()
|
Loading…
Reference in New Issue