[fp8] support asynchronous FP8 communication (#5997)

* fix

* fix

* fix

* support async all2all

* support async op for all gather

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Committed by flybird11111 (via GitHub)
commit 597b206001, parent 0978080a69

@@ -10,6 +10,18 @@ from torch.distributed import ReduceOp
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")


+class Handle:
+    def __init__(self, handles=[], remain_ops=None) -> None:
+        self.handles = handles
+        self.remain_ops = remain_ops
+
+    def wait(self):
+        for handle in self.handles:
+            handle.wait()
+        if self.remain_ops:
+            self.remain_ops()
+
+
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
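
The hunk above introduces Handle, which stores the pending communication requests plus a deferred post-processing callback; every collective below gains an async_op flag and, when it is set, returns such a Handle. A minimal usage sketch, not part of this commit, assuming the helpers are importable from colossalai.quantization.fp8 and that torch.distributed is already initialized with one CUDA device per rank:

# Illustrative sketch only. The import path and the helper name
# allreduce_with_overlap are assumptions, not taken from the diff.
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_reduce_fp8


def allreduce_with_overlap(grad: torch.Tensor, local_work: torch.Tensor) -> torch.Tensor:
    # Kick off the compressed all-reduce without blocking the host.
    handle = all_reduce_fp8(grad, fp8_format="e4m3", op=dist.ReduceOp.AVG, async_op=True)
    # Independent computation can proceed while the FP8 all_gathers are in flight.
    result = local_work @ local_work.T
    # wait() blocks on the stored communication handles, then runs the deferred
    # cast/concat step that writes the reduced values back into `grad`.
    handle.wait()
    return result
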
@@ -68,7 +80,9 @@ def cast_from_fp8(
     return ret.to(ret_type)


-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
+def all_reduce_fp8(
+    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -105,6 +119,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
@@ -113,19 +128,28 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
         summed_out.div_(world_size)

     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
-    dist.all_gather(scale_list, scale, group=group)
+    gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)

     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
-    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
-    for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
-    out = torch.cat(tensor_list, dim=0)
-    tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+    gather_tensor_handle = dist.all_gather(
+        tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cat_op():
+        for i in range(world_size):
+            tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        out = torch.cat(tensor_list, dim=0)
+        tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+
+    if async_op:
+        return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
+    else:
+        cat_op()


 def all_to_all_single_fp8(
     output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
-) -> None:
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
@@ -163,20 +187,27 @@ def all_to_all_single_fp8(
     else:
         output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]

-    dist.all_to_all(output_chunks, input_chunks, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    dist.all_gather(scale_list, scale, group=group)
-    cast_output_chunk = [
-        cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
-    ]
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)

-    tensor_out = torch.cat(cast_output_chunk, dim=0)
-    outputs_shape = list(input_shape)
-    if output_split_sizes is not None:
-        outputs_shape[0] = sum(output_split_sizes)
-    else:
-        outputs_shape = input_shape
-    output.data = tensor_out.view(outputs_shape).to(input_type)
+    def cast_op():
+        cast_output_chunk = [
+            cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
+        ]
+
+        tensor_out = torch.cat(cast_output_chunk, dim=0)
+        outputs_shape = list(input_shape)
+        if output_split_sizes is not None:
+            outputs_shape[0] = sum(output_split_sizes)
+        else:
+            outputs_shape = input_shape
+        output.data = tensor_out.view(outputs_shape).to(input_type)
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()


 def cast_to_fp8_pipeline(inp: Any) -> None:
@@ -250,7 +281,9 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["dtype"]


-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
+def reduce_scatter_fp8(
+    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
@@ -277,14 +310,20 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
         cast_input_list.append(ret)
         output_chunks.append(torch.empty_like(ret))
         output_scale_list.append(torch.empty_like(scale))
-    dist.all_to_all(output_chunks, cast_input_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
-    summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
-    for scale, out in zip(output_scale_list, output_chunks):
-        out = out.view(fp8_type)
-        summed_out += cast_from_fp8(out, scale, input_type)
-    output.data = summed_out
+    chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
+
+    def cast_op():
+        summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+        for scale, out in zip(output_scale_list, output_chunks):
+            out = out.view(fp8_type)
+            summed_out += cast_from_fp8(out, scale, input_type)
+        output.data = summed_out
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()


 def fp8_compress_ddp_grad_comm_hook_async(
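
For reference, a hedged sketch of how the new async path of reduce_scatter_fp8 above could be driven from a gradient-bucketing loop; the helper name reduce_scatter_bucket and the chunking scheme are illustrative, not part of this commit:

# Illustrative sketch only. Assumes torch.distributed is initialized and the
# bucket length is divisible by world_size; import path is an assumption.
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import reduce_scatter_fp8


def reduce_scatter_bucket(bucket: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    chunks = list(bucket.chunk(world_size))  # one equally sized chunk per rank
    shard = torch.empty_like(chunks[0])
    handle = reduce_scatter_fp8(shard, chunks, group, fp8_format="e5m2", async_op=True)
    # ... other work can overlap with the two all_to_all calls here ...
    handle.wait()  # the deferred cast_op() writes the summed shard into `shard`
    return shard
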
@@ -500,7 +539,8 @@ def all_gather_into_tensor_flat_fp8(
     output_shape: torch.Size,
     group: dist.ProcessGroup,
     fp8_format: str = "e4m3",
-):
+    async_op: bool = False,
+) -> Optional[Handle]:
     """all gather into tensor in fp8 format

     Args:
@@ -547,15 +587,25 @@ def all_gather_into_tensor_flat_fp8(
         scale = fp8_max / per_tensor_max
         fp8_input = (scale * input_tensor.float()).to(fp8_type)
         scale_inv = 1.0 / scale

     buffer = torch.empty_like(output_tensor, dtype=fp8_type)
-    dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
-    numel = output_shape.numel()
-    valid_buffer = buffer[:numel].reshape(output_shape)
-    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
-    output_tensor[:numel].copy_(valid_buffer.view(-1))
+    tensor_handle = dist.all_gather_into_tensor(
+        buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cast_op():
+        numel = output_shape.numel()
+        valid_buffer = buffer[:numel].reshape(output_shape)
+        valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
+        output_tensor[:numel].copy_(valid_buffer.view(-1))
+
+    if async_op:
+        return Handle([tensor_handle], cast_op)
+    else:
+        cast_op()


-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
     world_size = dist.get_world_size(group)
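
Likewise, a sketch of overlapping the flat FP8 all-gather above with other work, the way a ZeRO-style runtime might when prefetching parameters; the buffer names and the helper prefetch_param are assumptions, not from the diff:

# Illustrative sketch only: `flat_buffer` is the preallocated, padded full flat
# tensor and `shard` is this rank's slice of it; import path is an assumption.
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8


def prefetch_param(flat_buffer: torch.Tensor, shard: torch.Tensor,
                   param_shape: torch.Size, group: dist.ProcessGroup) -> torch.Tensor:
    handle = all_gather_into_tensor_flat_fp8(
        flat_buffer, shard, param_shape, group=group, async_op=True
    )
    # ... launch unrelated kernels while the uint8 all_gather_into_tensor runs ...
    handle.wait()  # then the deferred cast_op() dequantizes into flat_buffer
    return flat_buffer[: param_shape.numel()].view(param_shape)
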
@@ -573,17 +623,23 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
     output_scale_list = [torch.empty_like(x) for x in scale_list]
     output_tensor_list = [torch.empty_like(x) for x in tensor_list]

-    dist.all_to_all(output_tensor_list, tensor_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
+    tensor_handle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)

-    for i in range(world_size):
-        scale = output_scale_list[i]
-        tensor = output_tensor_list[i]
-        tensor = tensor.view(fp8_type)
-        output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+    def cast_op():
+        for i in range(world_size):
+            scale = output_scale_list[i]
+            tensor = output_tensor_list[i]
+            tensor = tensor.view(fp8_type)
+            output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+
+    if async_op:
+        return Handle([tensor_handle, scale_handle], cast_op)
+    else:
+        cast_op()


-def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
+def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
     world_size = dist.get_world_size(group)
@@ -593,13 +649,19 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
     input_ = ret.view(torch.uint8)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-    dist.all_gather(tensor_list, input_, group=group)
-    dist.all_gather(scale_list, scale, group=group)
-    for i in range(world_size):
-        output = tensor_list[i].view(fp8_type)
-        scale = scale_list[i]
-        output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+    chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
+
+    def cast_op():
+        for i in range(world_size):
+            output = tensor_list[i].view(fp8_type)
+            scale = scale_list[i]
+            output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()


 class _LinearFp8(torch.autograd.Function):

@@ -10,19 +10,24 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
-@parameterize("dtype", [torch.bfloat16])
-def check_all2all(shape, dtype):
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+@parameterize("async_op", [True, False])
+def check_all2all(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output = torch.empty_like(x)
     output_fp8 = torch.empty_like(x)
-    dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
-    all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
+    origin_handle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
+    fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)


 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_all2all_uneven(shape, dtype):
+@parameterize("async_op", [True, False])
+def check_all2all_uneven(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     input_split_sizes = [3, 3, 1, 1]
     if dist.get_rank() in [0, 1]:
@@ -33,22 +38,25 @@ def check_all2all_uneven(shape, dtype):
     output_shape[0] = sum(output_split_sizes)
     output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
     output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
-    dist.all_to_all_single(
+    origin_handle = dist.all_to_all_single(
         output,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
-    all_to_all_single_fp8(
+    fp8_handle = all_to_all_single_fp8(
         output_fp8,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)

@@ -12,7 +12,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

 @parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_4gpu(shape, dtype):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, async_op):
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
@@ -22,7 +23,9 @@ def check_4gpu(shape, dtype):
     flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
     output = torch.empty_like(flat_padded_x)
     chunk = flat_padded_x.chunk(world_size)[rank].clone()
-    all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group())
+    handle = all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        handle.wait()
     assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)

@@ -22,15 +22,22 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 )
 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
-def check_4gpu(shape, dtype, fp8_format):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, fp8_format, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     x_fp8 = x.clone()
-    dist.all_reduce(x)
-    all_reduce_fp8(x_fp8, fp8_format=fp8_format)
+    origin_handle = dist.all_reduce(x, async_op=async_op)
+    fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(x, x_fp8, rtol=0.1, atol=0.1)

-    dist.all_reduce(x, op=dist.ReduceOp.AVG)
-    all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format)
+    origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op)
+    fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(x, x_fp8, rtol=0.1, atol=0.1)

@@ -24,13 +24,17 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 )
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
-def check_4gpu(shape, dtype, fp8_format):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, fp8_format, async_op):
     world_size = dist.get_world_size()
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output_list = [torch.empty_like(x) for _ in range(world_size)]
     output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
-    gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
-    dist.all_gather(output_list, x, group=_get_default_group())
+    fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
+    origin_handle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        fp8_handle.wait()
+        origin_handle.wait()
     assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)
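
The updated tests above run inside the surrounding spawn / rerun_if_address_is_in_use harness, which this diff does not show. A standalone approximation of the gather check, runnable on a single node with torchrun --nproc_per_node=2, might look like the sketch below; the import path and the plain-NCCL launcher are assumptions, not part of the commit:

# Illustrative sketch only: 2-rank async gather_fp8 check without the
# colossalai.testing harness. Requires CUDA + NCCL.
import torch
import torch.distributed as dist
from torch.testing import assert_close

from colossalai.quantization.fp8 import gather_fp8


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # single node: rank == local rank
    world_size = dist.get_world_size()

    x = torch.rand(8, 16, dtype=torch.bfloat16, device="cuda")
    ref = [torch.empty_like(x) for _ in range(world_size)]
    out = [torch.empty_like(x) for _ in range(world_size)]

    fp8_handle = gather_fp8(out, x, fp8_format="e4m3", async_op=True)
    ref_handle = dist.all_gather(ref, x, async_op=True)
    fp8_handle.wait()  # waits on the FP8 all_gathers, then runs the deferred cast
    ref_handle.wait()

    assert_close(out, ref, rtol=0.1, atol=0.1)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()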
