diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index fe87e317d..4dd7db236 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -10,6 +10,18 @@ from torch.distributed import ReduceOp
 
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
 
+class Handle:
+    def __init__(self, handles=None, remain_ops=None) -> None:
+        self.handles = handles if handles is not None else []
+        self.remain_ops = remain_ops
+
+    def wait(self):
+        for handle in self.handles:
+            handle.wait()
+        if self.remain_ops:
+            self.remain_ops()
+
+
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
@@ -68,7 +80,9 @@ def cast_from_fp8(
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
+def all_reduce_fp8(
+    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -105,6 +119,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
@@ -113,19 +128,28 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
         summed_out.div_(world_size)
 
     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
-    dist.all_gather(scale_list, scale, group=group)
+    gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
-    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
-    for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
-    out = torch.cat(tensor_list, dim=0)
-    tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+    gather_tensor_handle = dist.all_gather(
+        tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cat_op():
+        for i in range(world_size):
+            tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        out = torch.cat(tensor_list, dim=0)
+        tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+
+    if async_op:
+        return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
+    else:
+        cat_op()
 
 
 def all_to_all_single_fp8(
     output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
-) -> None:
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
@@ -163,20 +187,27 @@ def all_to_all_single_fp8(
     else:
         output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
 
-    dist.all_to_all(output_chunks, input_chunks, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    dist.all_gather(scale_list, scale, group=group)
-    cast_output_chunk = [
-        cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
-    ]
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
-    tensor_out = torch.cat(cast_output_chunk, dim=0)
-    outputs_shape = list(input_shape)
-    if output_split_sizes is not None:
-        outputs_shape[0] = sum(output_split_sizes)
+    def cast_op():
+        cast_output_chunk = [
+            cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
+        ]
+
+        tensor_out = torch.cat(cast_output_chunk, dim=0)
+        outputs_shape = list(input_shape)
+        if output_split_sizes is not None:
+            outputs_shape[0] = sum(output_split_sizes)
+        else:
+            outputs_shape = input_shape
+        output.data = tensor_out.view(outputs_shape).to(input_type)
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
     else:
-        outputs_shape = input_shape
-    output.data = tensor_out.view(outputs_shape).to(input_type)
+        cast_op()
 
 
 def cast_to_fp8_pipeline(inp: Any) -> None:
@@ -250,7 +281,9 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["dtype"]
 
 
-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
+def reduce_scatter_fp8(
+    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
@@ -277,14 +310,20 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
         cast_input_list.append(ret)
         output_chunks.append(torch.empty_like(ret))
         output_scale_list.append(torch.empty_like(scale))
-    dist.all_to_all(output_chunks, cast_input_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
 
-    summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
-    for scale, out in zip(output_scale_list, output_chunks):
-        out = out.view(fp8_type)
-        summed_out += cast_from_fp8(out, scale, input_type)
-    output.data = summed_out
+    def cast_op():
+        summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+        for scale, out in zip(output_scale_list, output_chunks):
+            out = out.view(fp8_type)
+            summed_out += cast_from_fp8(out, scale, input_type)
+        output.data = summed_out
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 def fp8_compress_ddp_grad_comm_hook_async(
@@ -500,7 +539,8 @@ def all_gather_into_tensor_flat_fp8(
     output_shape: torch.Size,
     group: dist.ProcessGroup,
     fp8_format: str = "e4m3",
-):
+    async_op: bool = False,
+) -> Optional[Handle]:
     """all gather into tensor in fp8 format
 
     Args:
@@ -547,15 +587,25 @@ def all_gather_into_tensor_flat_fp8(
         scale = fp8_max / per_tensor_max
         fp8_input = (scale * input_tensor.float()).to(fp8_type)
     scale_inv = 1.0 / scale
+
     buffer = torch.empty_like(output_tensor, dtype=fp8_type)
-    dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
-    numel = output_shape.numel()
-    valid_buffer = buffer[:numel].reshape(output_shape)
-    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
-    output_tensor[:numel].copy_(valid_buffer.view(-1))
+    tensor_handle = dist.all_gather_into_tensor(
+        buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cast_op():
+        numel = output_shape.numel()
+        valid_buffer = buffer[:numel].reshape(output_shape)
+        valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
+        output_tensor[:numel].copy_(valid_buffer.view(-1))
+
+    if async_op:
+        return Handle([tensor_handle], cast_op)
+    else:
+        cast_op()
 
 
-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
     world_size = dist.get_world_size(group)
 
@@ -573,17 +623,23 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
     output_scale_list = [torch.empty_like(x) for x in scale_list]
     output_tensor_list = [torch.empty_like(x) for x in tensor_list]
-    dist.all_to_all(output_tensor_list, tensor_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
+    tensor_handle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
 
-    for i in range(world_size):
-        scale = output_scale_list[i]
-        tensor = output_tensor_list[i]
-        tensor = tensor.view(fp8_type)
-        output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+    def cast_op():
+        for i in range(world_size):
+            scale = output_scale_list[i]
+            tensor = output_tensor_list[i]
+            tensor = tensor.view(fp8_type)
+            output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+
+    if async_op:
+        return Handle([tensor_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
-def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
+def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
     world_size = dist.get_world_size(group)
 
@@ -593,13 +649,19 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
     input_ = ret.view(torch.uint8)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-    dist.all_gather(tensor_list, input_, group=group)
-    dist.all_gather(scale_list, scale, group=group)
+    chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
-    for i in range(world_size):
-        output = tensor_list[i].view(fp8_type)
-        scale = scale_list[i]
-        output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+    def cast_op():
+        for i in range(world_size):
+            output = tensor_list[i].view(fp8_type)
+            scale = scale_list[i]
+            output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 class _LinearFp8(torch.autograd.Function):
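The collectives above all follow the same two-phase pattern: with `async_op=True` the call returns a `Handle` whose `wait()` first drains the communication work objects and then runs the deferred dequantize/copy callback. A minimal usage sketch of the new `all_reduce_fp8` path; the launch boilerplate, tensor shapes, and the `main` wrapper are illustrative only, assuming a multi-GPU `torchrun` launch:

```python
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_reduce_fp8


def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    # Gradient-sized tensor to be averaged across ranks in fp8 (shape illustrative).
    grad = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

    # With async_op=True the call now returns a Handle instead of None.
    handle = all_reduce_fp8(grad, fp8_format="e4m3", op=dist.ReduceOp.AVG, async_op=True)

    # Independent work can overlap with the in-flight all_gathers.
    other = torch.randn(1024, 1024, device="cuda")
    other = other @ other

    # wait() drains the communication handles, then runs the deferred
    # cat/copy (`cat_op`) that writes the reduced result back into `grad`.
    handle.wait()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Note that only the final pair of all_gathers is made asynchronous here; the reduce phase before them (the all_to_all and the scale all_gather) still runs eagerly, so the overlap window covers the gather half of the algorithm.
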
diff --git a/tests/test_fp8/test_all_to_all_single.py b/tests/test_fp8/test_all_to_all_single.py
index 88becd3f0..722cbce9a 100644
--- a/tests/test_fp8/test_all_to_all_single.py
+++ b/tests/test_fp8/test_all_to_all_single.py
@@ -10,19 +10,24 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
-@parameterize("dtype", [torch.bfloat16])
-def check_all2all(shape, dtype):
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+@parameterize("async_op", [True, False])
+def check_all2all(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output = torch.empty_like(x)
     output_fp8 = torch.empty_like(x)
-    dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
-    all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
+    origin_handle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
+    fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
 
 
 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_all2all_uneven(shape, dtype):
+@parameterize("async_op", [True, False])
+def check_all2all_uneven(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     input_split_sizes = [3, 3, 1, 1]
     if dist.get_rank() in [0, 1]:
@@ -33,22 +38,25 @@ def check_all2all_uneven(shape, dtype):
     output_shape[0] = sum(output_split_sizes)
     output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
     output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
-    dist.all_to_all_single(
+    origin_handle = dist.all_to_all_single(
         output,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
-    all_to_all_single_fp8(
+    fp8_handle = all_to_all_single_fp8(
         output_fp8,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
 
diff --git a/tests/test_fp8/test_fp8_allgather_flat.py b/tests/test_fp8/test_fp8_allgather_flat.py
index 35e8796c2..2d43e5bd5 100644
--- a/tests/test_fp8/test_fp8_allgather_flat.py
+++ b/tests/test_fp8/test_fp8_allgather_flat.py
@@ -12,7 +12,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 @parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_4gpu(shape, dtype):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, async_op):
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
@@ -22,7 +23,9 @@ def check_4gpu(shape, dtype):
     flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
     output = torch.empty_like(flat_padded_x)
     chunk = flat_padded_x.chunk(world_size)[rank].clone()
-    all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group())
+    handle = all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        handle.wait()
     assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)
 
diff --git a/tests/test_fp8/test_fp8_allreduce.py b/tests/test_fp8/test_fp8_allreduce.py
index c23959b5d..ccc43ed29 100644
--- a/tests/test_fp8/test_fp8_allreduce.py
+++ b/tests/test_fp8/test_fp8_allreduce.py
@@ -22,15 +22,22 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 )
 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
-def check_4gpu(shape, dtype, fp8_format):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, fp8_format, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     x_fp8 = x.clone()
-    dist.all_reduce(x)
-    all_reduce_fp8(x_fp8, fp8_format=fp8_format)
+    origin_handle = dist.all_reduce(x, async_op=async_op)
+    fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(x, x_fp8, rtol=0.1, atol=0.1)
 
-    dist.all_reduce(x, op=dist.ReduceOp.AVG)
-    all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format)
+    origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op)
+    fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(x, x_fp8, rtol=0.1, atol=0.1)
 
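All of these tests rely on the same invariant: with `async_op=True` the output tensors are undefined until `wait()` returns, because the FP8 dequantization runs in the deferred callback rather than inside the collective itself. A single-process sketch of that contract, using a hypothetical `FakeWork` stand-in for the `torch.distributed` work object:

```python
class FakeWork:
    """Hypothetical stand-in for the async work object returned by
    torch.distributed collectives when async_op=True."""

    def wait(self):
        print("communication finished")


def deferred_cast():
    print("dequantize fp8 chunks and write the output tensor")


# Mirrors the Handle class added to fp8.py above.
class Handle:
    def __init__(self, handles=None, remain_ops=None) -> None:
        self.handles = handles if handles is not None else []
        self.remain_ops = remain_ops

    def wait(self):
        for handle in self.handles:
            handle.wait()
        if self.remain_ops:
            self.remain_ops()


# Until wait() is called, nothing has been dequantized yet.
Handle([FakeWork(), FakeWork()], deferred_cast).wait()
# -> communication finished
# -> communication finished
# -> dequantize fp8 chunks and write the output tensor
```
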
diff --git a/tests/test_fp8/test_fp8_gather.py b/tests/test_fp8/test_fp8_gather.py
index 79d1d4ea4..40c2ccb9a 100644
--- a/tests/test_fp8/test_fp8_gather.py
+++ b/tests/test_fp8/test_fp8_gather.py
@@ -24,13 +24,17 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 )
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
-def check_4gpu(shape, dtype, fp8_format):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, fp8_format, async_op):
     world_size = dist.get_world_size()
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output_list = [torch.empty_like(x) for _ in range(world_size)]
     output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
-    gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
-    dist.all_gather(output_list, x, group=_get_default_group())
+    fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
+    origin_handle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        fp8_handle.wait()
+        origin_handle.wait()
     assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)
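`reduce_scatter_fp8` also gains `async_op` in this diff but no test for it appears here. A sketch of how it could be exercised in the same style; the 4-rank `torchrun` launch, shapes, and `main` wrapper are hypothetical:

```python
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import reduce_scatter_fp8


def main():
    dist.init_process_group(backend="nccl")
    world_size = dist.get_world_size()
    torch.cuda.set_device(dist.get_rank())

    # One equally sized chunk per rank (sizes illustrative).
    x = torch.rand(world_size * 16, dtype=torch.float16, device="cuda")
    input_list = list(x.chunk(world_size))
    output = torch.empty_like(input_list[0])

    handle = reduce_scatter_fp8(output, input_list, group=dist.group.WORLD, async_op=True)
    handle.wait()  # `output` is only valid once the deferred cast has run

    # Compare against the uncompressed collective, as the other tests do.
    expected = torch.empty_like(output)
    dist.reduce_scatter(expected, list(x.chunk(world_size)))
    torch.testing.assert_close(output, expected, rtol=0.1, atol=0.1)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```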