|
|
|
@ -69,7 +69,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
world_size = dist.get_world_size() |
|
|
|
|
rank = dist.get_rank() |
|
|
|
|
dist.get_rank() |
|
|
|
|
input_type = tensor.dtype |
|
|
|
|
input_shape = tensor.shape |
|
|
|
|
input_device = tensor.device |
|
|
|
@ -102,4 +102,4 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
|
|
|
|
|
for i in range(world_size): |
|
|
|
|
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] |
|
|
|
|
tensor_out = torch.cat(tensor_list, dim=0) |
|
|
|
|
tensor.data = tensor_out.view(input_shape).to(input_type) |
|
|
|
|
tensor.data = tensor_out.view(input_shape).to(input_type) |
|
|
|
|