mirror of https://github.com/hpcaitech/ColossalAI
fix (#5976)
parent ccabcf6485
commit 7739629b9d
@@ -376,28 +376,6 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
         output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))


-def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
-    world_size = dist.get_world_size(group)
-    per_slice_len = input_tensor.size(0) // world_size
-    input_type = input_tensor.dtype
-    ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
-    fp8_type = ret.dtype
-    input_tensor = ret.view(torch.uint8)
-    tensor = torch.empty_like(input_tensor)
-    scale_list = [torch.empty_like(scale) for _ in range(world_size)]
-    dist.all_to_all_single(tensor, input_tensor, group=group)
-    dist.all_gather(scale_list, scale, group=group)
-    cast_tensor_list = []
-
-    for i in range(world_size):
-        output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type)
-        output_part = cast_from_fp8(output_part, scale_list[i], input_type)
-        cast_tensor_list.append(output_part)
-    output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0))
-
-
 def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
     world_size = dist.get_world_size(group)
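For readers skimming the hunk, here is a minimal sketch of how the removed all_to_all_single_fp8 helper would be driven. The torchrun launch, the NCCL backend, and the import path are assumptions for illustration; they are not part of this diff.

# Illustrative driver for all_to_all_single_fp8, the helper removed above.
# Assumptions: launched with `torchrun --nproc_per_node=2 demo.py`, NCCL backend,
# one GPU per rank; the import path below is an assumption, not shown in this diff.
import torch
import torch.distributed as dist

from colossalai.quantization.fp8 import all_to_all_single_fp8  # assumed path


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Dim 0 must split evenly across ranks: each rank sends one slice to every peer.
    input_tensor = torch.randn(world_size * 4, 8, dtype=torch.float16, device="cuda")
    output_tensor = torch.empty_like(input_tensor)

    # Casts the input to FP8 (e5m2), exchanges the uint8 payload with
    # dist.all_to_all_single, all-gathers the per-rank scales, then dequantizes
    # each received slice back to float16 into output_tensor.
    all_to_all_single_fp8(output_tensor, input_tensor, fp8_format="e5m2")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Because the payload travels as one byte per element plus a single scale per rank, this roughly halves the bytes on the wire relative to a float16 all-to-all, which is presumably the motivation for the FP8 variants in this file.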
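The hunk calls cast_to_fp8 and cast_from_fp8 but does not show them. Below is a self-contained sketch of the per-tensor scale-and-cast round trip they are named after, using the stock torch.float8_e5m2 dtype; the scaling convention here is an assumption for illustration, not the library implementation.

# Standalone illustration of a per-tensor FP8 (e5m2) round trip. This is NOT the
# library's cast_to_fp8/cast_from_fp8; the scale convention is an assumption.
import torch


def fp8_roundtrip(x: torch.Tensor, fp8_dtype=torch.float8_e5m2):
    fp8_max = torch.finfo(fp8_dtype).max
    # Per-tensor scale that maps the largest magnitude near the top of the FP8 range.
    amax = x.abs().max().float().clamp(min=1e-12)
    scale = fp8_max / amax
    payload = (x.float() * scale).to(fp8_dtype)  # what would travel as uint8
    # "Receiver" side: dequantize with the communicated scale.
    recovered = (payload.float() / scale).to(x.dtype)
    return payload, scale, recovered


x = torch.randn(1024, dtype=torch.float16)
_, _, y = fp8_roundtrip(x)
print((x - y).abs().max())  # small quantization error from e5m2's 2-bit mantissa

In the removed helper, the payload and the scales take separate collectives (all_to_all_single for the data, all_gather for the scales), so each rank can dequantize every peer's slice with that peer's own scale.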