mirror of https://github.com/hpcaitech/ColossalAI
fix (#5976)
parent ccabcf6485
commit 7739629b9d

@@ -376,28 +376,6 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
         output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
 
 
-def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
-
-    world_size = dist.get_world_size(group)
-
-    per_slice_len = input_tensor.size(0) // world_size
-    input_type = input_tensor.dtype
-    ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
-    fp8_type = ret.dtype
-    input_tensor = ret.view(torch.uint8)
-    tensor = torch.empty_like(input_tensor)
-    scale_list = [torch.empty_like(scale) for _ in range(world_size)]
-    dist.all_to_all_single(tensor, input_tensor, group=group)
-    dist.all_gather(scale_list, scale, group=group)
-    cast_tensor_list = []
-
-    for i in range(world_size):
-        output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type)
-        output_part = cast_from_fp8(output_part, scale_list[i], input_type)
-        cast_tensor_list.append(output_part)
-    output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0))
-
-
 def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
 
     world_size = dist.get_world_size(group)
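
For context on the hunk above: all_to_all_single_fp8 quantizes the whole input with a single per-tensor scale, exchanges the raw uint8 payload via dist.all_to_all_single, all-gathers every rank's scale, and then dequantizes each received slice with the scale of the rank that produced it. The sketch below shows what such per-tensor cast helpers typically look like; it is only an illustrative stand-in for cast_to_fp8 / cast_from_fp8, not ColossalAI's implementation, and assumes a PyTorch build with float8 dtypes (2.1 or newer).

import torch

def cast_to_fp8_sketch(x: torch.Tensor, fp8_format: str = "e5m2"):
    # Illustrative stand-in, not ColossalAI's cast_to_fp8.
    # Pick the fp8 dtype and a per-tensor scale so the largest magnitude
    # in x lands on the edge of the representable fp8 range.
    fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    scale = fp8_max / x.abs().max().clamp(min=1e-12).float()
    return (x.float() * scale).to(fp8_dtype), scale

def cast_from_fp8_sketch(x: torch.Tensor, scale: torch.Tensor, out_dtype: torch.dtype):
    # Undo the scaling and restore the original dtype.
    return (x.float() / scale).to(out_dtype)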
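
Below is a hedged end-to-end sketch of exercising such an fp8 all-to-all against the uncompressed collective on two GPUs. The import path colossalai.quantization.fp8, the port number, and the spawn launcher are assumptions for illustration and are not taken from this commit; if the helper is not available at that path, paste the definition from the hunk above instead. Since fp8 is lossy, only approximate agreement with the bf16 exchange is expected.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Assumed import path; adjust to wherever the helper actually lives.
    from colossalai.quantization.fp8 import all_to_all_single_fp8

    # dim 0 must be divisible by world_size (see per_slice_len in the hunk).
    x = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)
    out_fp8 = torch.empty_like(x)
    out_ref = torch.empty_like(x)

    all_to_all_single_fp8(out_fp8, x, fp8_format="e5m2")  # compressed exchange
    dist.all_to_all_single(out_ref, x)                    # reference exchange

    # fp8 is lossy: expect a small but nonzero error versus the bf16 result.
    print(f"rank {rank}: max abs err = {(out_fp8 - out_ref).abs().max().item():.4f}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)  # needs at least 2 CUDA devices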