Browse Source

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/5885/head
pre-commit-ci[bot] 5 months ago
parent
commit
e17f835df7
  1. 4
      colossalai/quantization/fp8.py

4
colossalai/quantization/fp8.py

@ -69,7 +69,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
""" """
world_size = dist.get_world_size() world_size = dist.get_world_size()
rank = dist.get_rank() dist.get_rank()
input_type = tensor.dtype input_type = tensor.dtype
input_shape = tensor.shape input_shape = tensor.shape
input_device = tensor.device input_device = tensor.device
@ -102,4 +102,4 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
for i in range(world_size): for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
tensor_out = torch.cat(tensor_list, dim=0) tensor_out = torch.cat(tensor_list, dim=0)
tensor.data = tensor_out.view(input_shape).to(input_type) tensor.data = tensor_out.view(input_shape).to(input_type)

Loading…
Cancel
Save