pull/5981/head
flybird11111 2024-08-07 18:58:39 +08:00 committed by GitHub
parent ccabcf6485
commit 7739629b9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 0 additions and 22 deletions

View File

@ -376,28 +376,6 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
output_list[i].copy_(cast_from_fp8(tensor, scale, input_type)) output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
    """Run an all-to-all exchange on a single tensor, transmitting it as FP8.

    The input is cast to FP8 with ``cast_to_fp8``, reinterpreted as ``uint8``
    bytes, and exchanged with ``dist.all_to_all_single``. Every rank's scale is
    collected with ``dist.all_gather`` so that each received slice can be cast
    back to the input dtype with the scale of the rank that produced it.

    Args:
        output_tensor: Destination tensor; the restored result is written into
            it in place with ``copy_``.
        input_tensor: Tensor to scatter. It is split into ``world_size`` equal
            slices along dim 0.
            NOTE(review): ``size(0)`` is floor-divided by the world size, so
            this presumably expects it to be an exact multiple -- confirm.
        group: Process group forwarded to the ``torch.distributed`` calls.
        fp8_format: FP8 format name forwarded to ``cast_to_fp8``.

    Returns:
        None. The result is delivered through ``output_tensor``.
    """
    num_ranks = dist.get_world_size(group)
    orig_dtype = input_tensor.dtype
    rows_per_rank = input_tensor.size(0) // num_ranks

    fp8_tensor, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
    fp8_dtype = fp8_tensor.dtype

    # Ship the payload as raw uint8 bytes (presumably because the collective
    # backend does not accept FP8 dtypes directly -- TODO confirm).
    send_buf = fp8_tensor.view(torch.uint8)
    recv_buf = torch.empty_like(send_buf)
    dist.all_to_all_single(recv_buf, send_buf, group=group)

    # Each sender quantised with its own scale, so collect all of them.
    gathered_scales = [torch.empty_like(scale) for _ in range(num_ranks)]
    dist.all_gather(gathered_scales, scale, group=group)

    # Slice ``rank`` of the receive buffer is paired with entry ``rank`` of the
    # gathered scales when casting back to the original dtype.
    restored = [
        cast_from_fp8(
            recv_buf[rows_per_rank * rank : rows_per_rank * (rank + 1)].view(fp8_dtype),
            peer_scale,
            orig_dtype,
        )
        for rank, peer_scale in enumerate(gathered_scales)
    ]
    output_tensor.copy_(torch.concatenate(restored, dim=0))
def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"): def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
world_size = dist.get_world_size(group) world_size = dist.get_world_size(group)