diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index d405de2de..c880cd4aa 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -69,7 +69,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: """ world_size = dist.get_world_size() - rank = dist.get_rank() + dist.get_rank() input_type = tensor.dtype input_shape = tensor.shape input_device = tensor.device @@ -102,4 +102,4 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: for i in range(world_size): tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] tensor_out = torch.cat(tensor_list, dim=0) - tensor.data = tensor_out.view(input_shape).to(input_type) \ No newline at end of file + tensor.data = tensor_out.view(input_shape).to(input_type)