[fp8] support asynchronous FP8 communication (#5997)

* fix

* fix

* fix

* support async all2all

* support async op for all gather

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6002/head
flybird11111 3 months ago committed by GitHub
parent 0978080a69
commit 597b206001
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -10,6 +10,18 @@ from torch.distributed import ReduceOp
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
class Handle:
def __init__(self, handles=[], remain_ops=None) -> None:
self.handles = handles
self.remain_ops = remain_ops
def wait(self):
for handle in self.handles:
handle.wait()
if self.remain_ops:
self.remain_ops()
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
@ -68,7 +80,9 @@ def cast_from_fp8(
return ret.to(ret_type)
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
def all_reduce_fp8(
tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
) -> Optional[Handle]:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.
@ -105,6 +119,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale, group=group)
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
@ -113,19 +128,28 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
summed_out.div_(world_size)
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
dist.all_gather(scale_list, scale, group=group)
gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
gather_tensor_handle = dist.all_gather(
tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
)
def cat_op():
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
out = torch.cat(tensor_list, dim=0)
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
if async_op:
return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
else:
cat_op()
def all_to_all_single_fp8(
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> None:
) -> Optional[Handle]:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
@ -163,9 +187,11 @@ def all_to_all_single_fp8(
else:
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
dist.all_to_all(output_chunks, input_chunks, group=group)
chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale, group=group)
scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
def cast_op():
cast_output_chunk = [
cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
]
@ -178,6 +204,11 @@ def all_to_all_single_fp8(
outputs_shape = input_shape
output.data = tensor_out.view(outputs_shape).to(input_type)
if async_op:
return Handle([chunk_handle, scale_hanle], cast_op)
else:
cast_op()
def cast_to_fp8_pipeline(inp: Any) -> None:
"""
@ -250,7 +281,9 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
del inp["dtype"]
def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
def reduce_scatter_fp8(
output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
) -> Optional[Handle]:
r"""
This is an in-place operation for compressed reduce_scatter using fp8.
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
@ -277,15 +310,21 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
cast_input_list.append(ret)
output_chunks.append(torch.empty_like(ret))
output_scale_list.append(torch.empty_like(scale))
dist.all_to_all(output_chunks, cast_input_list, group=group)
dist.all_to_all(output_scale_list, scale_list, group=group)
chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
def cast_op():
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(output_scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
output.data = summed_out
if async_op:
return Handle([chunk_handle, scale_handle], cast_op)
else:
cast_op()
def fp8_compress_ddp_grad_comm_hook_async(
process_group: dist.ProcessGroup,
@ -500,7 +539,8 @@ def all_gather_into_tensor_flat_fp8(
output_shape: torch.Size,
group: dist.ProcessGroup,
fp8_format: str = "e4m3",
):
async_op: bool = False,
) -> Optional[Handle]:
"""all gather into tensor in fp8 format
Args:
@ -547,15 +587,25 @@ def all_gather_into_tensor_flat_fp8(
scale = fp8_max / per_tensor_max
fp8_input = (scale * input_tensor.float()).to(fp8_type)
scale_inv = 1.0 / scale
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
tensor_handle = dist.all_gather_into_tensor(
buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
)
def cast_op():
numel = output_shape.numel()
valid_buffer = buffer[:numel].reshape(output_shape)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
output_tensor[:numel].copy_(valid_buffer.view(-1))
if async_op:
return Handle([tensor_handle], cast_op)
else:
cast_op()
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
world_size = dist.get_world_size(group)
@ -573,17 +623,23 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
output_scale_list = [torch.empty_like(x) for x in scale_list]
output_tensor_list = [torch.empty_like(x) for x in tensor_list]
dist.all_to_all(output_tensor_list, tensor_list, group=group)
dist.all_to_all(output_scale_list, scale_list, group=group)
tensor_hanle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
def cast_op():
for i in range(world_size):
scale = output_scale_list[i]
tensor = output_tensor_list[i]
tensor = tensor.view(fp8_type)
output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
if async_op:
return Handle([tensor_hanle, scale_handle], cast_op)
else:
cast_op()
def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
world_size = dist.get_world_size(group)
@ -593,14 +649,20 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
input_ = ret.view(torch.uint8)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
dist.all_gather(tensor_list, input_, group=group)
dist.all_gather(scale_list, scale, group=group)
chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
scale_hanle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
def cast_op():
for i in range(world_size):
output = tensor_list[i].view(fp8_type)
scale = scale_list[i]
output_list[i].copy_(cast_from_fp8(output, scale, input_type))
if async_op:
return Handle([chunk_handle, scale_hanle], cast_op)
else:
cast_op()
class _LinearFp8(torch.autograd.Function):
@staticmethod

@ -10,19 +10,24 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16])
def check_all2all(shape, dtype):
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
def check_all2all(shape, dtype, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
output = torch.empty_like(x)
output_fp8 = torch.empty_like(x)
dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
origin_hanle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
if async_op:
origin_hanle.wait()
fp8_handle.wait()
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
@parameterize("shape", [(8, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_all2all_uneven(shape, dtype):
@parameterize("async_op", [True, False])
def check_all2all_uneven(shape, dtype, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
input_split_sizes = [3, 3, 1, 1]
if dist.get_rank() in [0, 1]:
@ -33,22 +38,25 @@ def check_all2all_uneven(shape, dtype):
output_shape[0] = sum(output_split_sizes)
output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
dist.all_to_all_single(
origin_hanle = dist.all_to_all_single(
output,
x,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=_get_default_group(),
async_op=False,
async_op=async_op,
)
all_to_all_single_fp8(
fp8_handle = all_to_all_single_fp8(
output_fp8,
x,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=_get_default_group(),
async_op=False,
async_op=async_op,
)
if async_op:
origin_hanle.wait()
fp8_handle.wait()
assert_close(output, output_fp8, rtol=0.1, atol=0.1)

@ -12,7 +12,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_4gpu(shape, dtype):
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, async_op):
world_size = dist.get_world_size()
rank = dist.get_rank()
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
@ -22,7 +23,9 @@ def check_4gpu(shape, dtype):
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
output = torch.empty_like(flat_padded_x)
chunk = flat_padded_x.chunk(world_size)[rank].clone()
all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group())
handle = all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group(), async_op=async_op)
if async_op:
handle.wait()
assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)

@ -22,15 +22,22 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
)
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, dtype, fp8_format):
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, fp8_format, async_op):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
x_fp8 = x.clone()
dist.all_reduce(x)
all_reduce_fp8(x_fp8, fp8_format=fp8_format)
origin_handle = dist.all_reduce(x, async_op=async_op)
fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op)
if async_op:
origin_handle.wait()
fp8_handle.wait()
assert_close(x, x_fp8, rtol=0.1, atol=0.1)
dist.all_reduce(x, op=dist.ReduceOp.AVG)
all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format)
origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op)
fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op)
if async_op:
origin_handle.wait()
fp8_handle.wait()
assert_close(x, x_fp8, rtol=0.1, atol=0.1)

@ -24,13 +24,17 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
)
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
def check_4gpu(shape, dtype, fp8_format):
@parameterize("async_op", [True, False])
def check_4gpu(shape, dtype, fp8_format, async_op):
world_size = dist.get_world_size()
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
output_list = [torch.empty_like(x) for _ in range(world_size)]
output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
dist.all_gather(output_list, x, group=_get_default_group())
fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
if async_op:
fp8_handle.wait()
origin_hanle.wait()
assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)

Loading…
Cancel
Save