mirror of https://github.com/hpcaitech/ColossalAI
[fp8] support asynchronous FP8 communication (#5997)
* fix
* fix
* fix
* support async all2all
* support async op for all gather
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0978080a69
commit 597b206001
@@ -10,6 +10,18 @@ from torch.distributed import ReduceOp

 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")

+
+class Handle:
+    def __init__(self, handles=[], remain_ops=None) -> None:
+        self.handles = handles
+        self.remain_ops = remain_ops
+
+    def wait(self):
+        for handle in self.handles:
+            handle.wait()
+        if self.remain_ops:
+            self.remain_ops()
+
+
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
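The Handle added here defers the FP8-to-original-dtype post-processing until wait() is called: it first blocks on every underlying communication work object, then runs the deferred dequantize/concat step exactly once. A minimal, self-contained sketch of that contract, using the Handle class from the hunk above (DummyWork is a stand-in for the work objects torch.distributed collectives return when async_op=True):

class DummyWork:
    # stand-in for the async work object returned by a collective
    def wait(self):
        print("communication finished")


handle = Handle(
    handles=[DummyWork(), DummyWork()],
    remain_ops=lambda: print("deferred fp8 -> original dtype cast runs last"),
)
handle.wait()  # waits on both works, then runs the deferred cast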
@@ -68,7 +80,9 @@ def cast_from_fp8(
     return ret.to(ret_type)


-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
+def all_reduce_fp8(
+    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -105,6 +119,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
@@ -113,19 +128,28 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
         summed_out.div_(world_size)

     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
-    dist.all_gather(scale_list, scale, group=group)
+    gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)

     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
-    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
-    for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
-    out = torch.cat(tensor_list, dim=0)
-    tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+    gather_tensor_handle = dist.all_gather(
+        tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cat_op():
+        for i in range(world_size):
+            tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        out = torch.cat(tensor_list, dim=0)
+        tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+
+    if async_op:
+        return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
+    else:
+        cat_op()


 def all_to_all_single_fp8(
     output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
-) -> None:
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_to_all_single using fp8.
     It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
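Before the diff moves on to all_to_all_single_fp8: taken together, the hunks above give all_reduce_fp8 an async path that returns a Handle. A hedged usage sketch, assuming a NCCL process group launched via e.g. torchrun and the import path colossalai.quantization.fp8 (which this mirror view does not show):

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_reduce_fp8

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

grad = torch.randn(1024, device="cuda")
handle = all_reduce_fp8(grad, fp8_format="e4m3", op=dist.ReduceOp.AVG, async_op=True)

# independent work can overlap with the two in-flight all_gathers here

handle.wait()  # waits, then cat_op() writes the reduced result back into grad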
@@ -163,9 +187,11 @@ def all_to_all_single_fp8(
     else:
         output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]

-    dist.all_to_all(output_chunks, input_chunks, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    dist.all_gather(scale_list, scale, group=group)
-    cast_output_chunk = [
-        cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
-    ]
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
+
+    def cast_op():
+        cast_output_chunk = [
+            cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
+        ]
@@ -178,6 +204,11 @@ def all_to_all_single_fp8(
-        outputs_shape = input_shape
-    output.data = tensor_out.view(outputs_shape).to(input_type)
+            outputs_shape = input_shape
+        output.data = tensor_out.view(outputs_shape).to(input_type)
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()


 def cast_to_fp8_pipeline(inp: Any) -> None:
     """
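The async all_to_all_single_fp8 path keeps two collectives in flight (the uint8 payload all_to_all and the scale all_gather); cast_op runs only once both complete. A sketch of the call pattern, with sizes chosen purely for illustration and the same assumed import path as above:

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_single_fp8

x = torch.rand(dist.get_world_size() * 4, 16, device="cuda")
out = torch.empty_like(x)

handle = all_to_all_single_fp8(out, x, group=dist.group.WORLD, async_op=True)
# overlap window: payload chunks and scales are still in flight
handle.wait()  # cast_op() dequantizes the received chunks into out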
@@ -250,7 +281,9 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["dtype"]


-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
+def reduce_scatter_fp8(
+    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
@@ -277,15 +310,21 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
         cast_input_list.append(ret)
         output_chunks.append(torch.empty_like(ret))
         output_scale_list.append(torch.empty_like(scale))
-    dist.all_to_all(output_chunks, cast_input_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
-    summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
-    for scale, out in zip(output_scale_list, output_chunks):
-        out = out.view(fp8_type)
-        summed_out += cast_from_fp8(out, scale, input_type)
-    output.data = summed_out
+    chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
+
+    def cast_op():
+        summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+        for scale, out in zip(output_scale_list, output_chunks):
+            out = out.view(fp8_type)
+            summed_out += cast_from_fp8(out, scale, input_type)
+        output.data = summed_out
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()


 def fp8_compress_ddp_grad_comm_hook_async(
     process_group: dist.ProcessGroup,
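One design point worth noting in the reduce_scatter_fp8 hunk: the reduction itself lives inside cast_op, so it happens only after wait(), because the chunks travel as raw FP8 bytes via all_to_all rather than being reduced in flight. A hedged sketch under the same assumptions as the earlier examples:

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import reduce_scatter_fp8

world_size = dist.get_world_size()
inputs = [torch.randn(256, device="cuda") for _ in range(world_size)]
output = torch.empty(256, device="cuda")

handle = reduce_scatter_fp8(output, inputs, group=dist.group.WORLD, async_op=True)
# ... overlap ...
handle.wait()  # all_to_all done; the FP8 chunks are then summed into output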
@@ -500,7 +539,8 @@ def all_gather_into_tensor_flat_fp8(
     output_shape: torch.Size,
     group: dist.ProcessGroup,
     fp8_format: str = "e4m3",
-):
+    async_op: bool = False,
+) -> Optional[Handle]:
     """all gather into tensor in fp8 format

     Args:
@@ -547,15 +587,25 @@ def all_gather_into_tensor_flat_fp8(
         scale = fp8_max / per_tensor_max
     fp8_input = (scale * input_tensor.float()).to(fp8_type)
     scale_inv = 1.0 / scale

     buffer = torch.empty_like(output_tensor, dtype=fp8_type)
-    dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
-    numel = output_shape.numel()
-    valid_buffer = buffer[:numel].reshape(output_shape)
-    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
-    output_tensor[:numel].copy_(valid_buffer.view(-1))
+    tensor_handle = dist.all_gather_into_tensor(
+        buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cast_op():
+        numel = output_shape.numel()
+        valid_buffer = buffer[:numel].reshape(output_shape)
+        valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
+        output_tensor[:numel].copy_(valid_buffer.view(-1))
+
+    if async_op:
+        return Handle([tensor_handle], cast_op)
+    else:
+        cast_op()


-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):

     world_size = dist.get_world_size(group)
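all_gather_into_tensor_flat_fp8 gathers flat, evenly padded shards (the ZeRO-style layout) and unpacks the valid region back to output_shape inside cast_op. An illustrative call with the padding arithmetic made explicit; shapes and names here are examples, not part of the diff:

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8

world_size = dist.get_world_size()
full_shape = torch.Size((8, 32))
numel = full_shape.numel()
pad = (world_size - numel % world_size) % world_size  # make shards divide evenly

shard = torch.randn((numel + pad) // world_size, device="cuda")
output = torch.empty(numel + pad, device="cuda")

handle = all_gather_into_tensor_flat_fp8(output, shard, full_shape, group=dist.group.WORLD, async_op=True)
handle.wait()  # valid region dequantized and copied back into output
full = output[:numel].view(full_shape)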
@@ -573,17 +623,23 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):

     output_scale_list = [torch.empty_like(x) for x in scale_list]
     output_tensor_list = [torch.empty_like(x) for x in tensor_list]
-    dist.all_to_all(output_tensor_list, tensor_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
-    for i in range(world_size):
-        scale = output_scale_list[i]
-        tensor = output_tensor_list[i]
-        tensor = tensor.view(fp8_type)
-        output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+    tensor_handle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
+
+    def cast_op():
+        for i in range(world_size):
+            scale = output_scale_list[i]
+            tensor = output_tensor_list[i]
+            tensor = tensor.view(fp8_type)
+            output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+
+    if async_op:
+        return Handle([tensor_handle, scale_handle], cast_op)
+    else:
+        cast_op()


-def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
+def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:

     world_size = dist.get_world_size(group)
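The list-based all_to_all_fp8 above follows the same recipe: quantize per-rank tensors, exchange uint8 payloads and scales, dequantize in cast_op. A short sketch of asynchronous use, same assumptions as before:

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_fp8

world_size = dist.get_world_size()
inputs = [torch.rand(8, device="cuda") for _ in range(world_size)]
outputs = [torch.empty(8, device="cuda") for _ in range(world_size)]

handle = all_to_all_fp8(outputs, inputs, group=dist.group.WORLD, async_op=True)
handle.wait()  # per-peer chunks dequantized into outputs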
@@ -593,14 +649,20 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
     input_ = ret.view(torch.uint8)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-    dist.all_gather(tensor_list, input_, group=group)
-    dist.all_gather(scale_list, scale, group=group)
-    for i in range(world_size):
-        output = tensor_list[i].view(fp8_type)
-        scale = scale_list[i]
-        output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+    chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
+
+    def cast_op():
+        for i in range(world_size):
+            output = tensor_list[i].view(fp8_type)
+            scale = scale_list[i]
+            output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()


 class _LinearFp8(torch.autograd.Function):
     @staticmethod
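gather_fp8, just above, mirrors dist.all_gather with FP8 compression on the wire; asynchronous use looks like the following sketch (same assumed import path):

import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import gather_fp8

world_size = dist.get_world_size()
x = torch.rand(4, 16, device="cuda")
outputs = [torch.empty_like(x) for _ in range(world_size)]

handle = gather_fp8(outputs, x, group=dist.group.WORLD, fp8_format="e5m2", async_op=True)
handle.wait()  # gathered uint8 payloads + scales dequantized into outputs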
@@ -10,19 +10,24 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
-@parameterize("dtype", [torch.bfloat16])
-def check_all2all(shape, dtype):
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+@parameterize("async_op", [True, False])
+def check_all2all(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output = torch.empty_like(x)
     output_fp8 = torch.empty_like(x)
-    dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
-    all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
+    origin_handle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
+    fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)


 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_all2all_uneven(shape, dtype):
+@parameterize("async_op", [True, False])
+def check_all2all_uneven(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     input_split_sizes = [3, 3, 1, 1]
     if dist.get_rank() in [0, 1]:
@@ -33,22 +38,25 @@ def check_all2all_uneven(shape, dtype):
     output_shape[0] = sum(output_split_sizes)
     output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
     output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
-    dist.all_to_all_single(
+    origin_handle = dist.all_to_all_single(
         output,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
-    all_to_all_single_fp8(
+    fp8_handle = all_to_all_single_fp8(
         output_fp8,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
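The test hunks omit the launcher boilerplate. In ColossalAI's test suite these parameterized checks are typically driven by a spawn-based harness roughly like the following; this is a sketch based on the imports shown in the hunk headers, and names such as run_dist and the port handling are assumptions, not the literal file contents:

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # each spawned process joins the group, then runs every parameterized check
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_all2all()
    check_all2all_uneven()


@rerun_if_address_is_in_use()
def test_all_to_all_single():
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_all_to_all_single()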
@@ -12,7 +12,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

 @parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_4gpu(shape, dtype):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, async_op):
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())

@@ -22,7 +23,9 @@ def check_4gpu(shape, dtype):
     flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
     output = torch.empty_like(flat_padded_x)
     chunk = flat_padded_x.chunk(world_size)[rank].clone()
-    all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group())
+    handle = all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        handle.wait()
     assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)
@@ -22,15 +22,22 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 )
 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
-def check_4gpu(shape, dtype, fp8_format):
+@parameterize("async_op", [True, False])
+def check_4gpu(shape, dtype, fp8_format, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     x_fp8 = x.clone()
-    dist.all_reduce(x)
-    all_reduce_fp8(x_fp8, fp8_format=fp8_format)
+    origin_handle = dist.all_reduce(x, async_op=async_op)
+    fp8_handle = all_reduce_fp8(x_fp8, fp8_format=fp8_format, async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(x, x_fp8, rtol=0.1, atol=0.1)

-    dist.all_reduce(x, op=dist.ReduceOp.AVG)
-    all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format)
+    origin_handle = dist.all_reduce(x, op=dist.ReduceOp.AVG, async_op=async_op)
+    fp8_handle = all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format, async_op=async_op)
+    if async_op:
+        origin_handle.wait()
+        fp8_handle.wait()
     assert_close(x, x_fp8, rtol=0.1, atol=0.1)
|
@ -24,13 +24,17 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
)
|
)
|
||||||
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@parameterize("fp8_format", ["e4m3", "e5m2"])
|
@parameterize("fp8_format", ["e4m3", "e5m2"])
|
||||||
def check_4gpu(shape, dtype, fp8_format):
|
@parameterize("async_op", [True, False])
|
||||||
|
def check_4gpu(shape, dtype, fp8_format, async_op):
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||||
output_list = [torch.empty_like(x) for _ in range(world_size)]
|
output_list = [torch.empty_like(x) for _ in range(world_size)]
|
||||||
output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
|
output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)]
|
||||||
gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format)
|
fp8_handle = gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format, async_op=async_op)
|
||||||
dist.all_gather(output_list, x, group=_get_default_group())
|
origin_hanle = dist.all_gather(output_list, x, group=_get_default_group(), async_op=async_op)
|
||||||
|
if async_op:
|
||||||
|
fp8_handle.wait()
|
||||||
|
origin_hanle.wait()
|
||||||
assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)
|
assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
|