From 1cb7bdad3b7e120855f493545522591ec4f1f49a Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 12 Apr 2022 09:44:40 +0800
Subject: [PATCH] [util] fixed communication API depth with PyTorch 1.9 (#721)

---
 colossalai/communication/collective.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py
index 7eeac241e..e0db6ca6c 100644
--- a/colossalai/communication/collective.py
+++ b/colossalai/communication/collective.py
@@ -211,7 +211,7 @@ def reduce(tensor: Tensor,
 def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
     r"""Modified from `torch.distributed.scatter_object_list ` to fix issues
     """
-    if dist._rank_not_in_group(group):
+    if dist.distributed_c10d._rank_not_in_group(group):
         return
 
     if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1):
@@ -220,7 +220,7 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s
 
     # set tensor device to cuda if backend is nccl
     device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu")
-    my_rank = dist.get_rank()    # use global rank
+    my_rank = dist.get_rank()  # use global rank
     if my_rank == src:
         tensor_list, tensor_sizes = zip(
             *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list])
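
For context, here is a minimal usage sketch of the patched helper (not part of the commit). It assumes a two-process CPU run launched with `torchrun --nproc_per_node=2 scatter_demo.py`; the script name, demo payload, and `gloo` backend are illustrative assumptions. Rank 0 supplies one picklable object per rank, and every rank receives its slice into a single-slot output list, mirroring the `torch.distributed.scatter_object_list` calling convention.

```python
# Hypothetical demo, not from the patch. Launch with, e.g.:
#   torchrun --nproc_per_node=2 scatter_demo.py
import torch.distributed as dist

from colossalai.communication.collective import scatter_object_list


def main():
    # torchrun sets RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT in the env;
    # gloo keeps the demo CPU-only (the patched helper only moves tensors to
    # CUDA when the backend is nccl)
    dist.init_process_group(backend='gloo')
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # the source rank provides one picklable object per rank;
    # other ranks pass None, which the function ignores
    if rank == 0:
        inputs = [{'dest_rank': r, 'payload': list(range(r + 1))} for r in range(world_size)]
    else:
        inputs = None

    # pre-allocated single-slot output list, as the size >= 1 check requires
    outputs = [None]
    scatter_object_list(outputs, inputs, src=0)
    print(f'rank {rank} received {outputs[0]}')

    dist.destroy_process_group()


if __name__ == '__main__':
    main()
```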