[util] fixed communication API depth with PyTorch 1.9 (#721)

pull/729/head
Frank Lee 2022-04-12 09:44:40 +08:00 committed by GitHub
parent 2412429d54
commit 1cb7bdad3b
1 changed file with 2 additions and 2 deletions


@@ -211,7 +211,7 @@ def reduce(tensor: Tensor,
 def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
     r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
     """
-    if dist._rank_not_in_group(group):
+    if dist.distributed_c10d._rank_not_in_group(group):
         return
     if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1):
@@ -220,7 +220,7 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s
     # set tensor device to cuda if backend is nccl
     device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu")
-    my_rank = dist.get_rank() # use global rank
+    my_rank = dist.get_rank()  # use global rank
     if my_rank == src:
         tensor_list, tensor_sizes = zip(
             *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list])
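
For context, a minimal usage sketch of the patched scatter_object_list is shown below. The import path (colossalai.communication) and the pre-initialized process group are assumptions, not part of this diff; the call semantics follow torch.distributed.scatter_object_list, where the source rank supplies one picklable object per rank and every rank receives its object in a pre-sized output list.

    # Minimal usage sketch (assumption: the patched helper is importable as
    # colossalai.communication.scatter_object_list and a process group has
    # already been initialized, e.g. torchrun + dist.init_process_group).
    import torch.distributed as dist
    from colossalai.communication import scatter_object_list

    def scatter_demo(src: int = 0):
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Only the source rank provides inputs: one picklable object per rank.
        if rank == src:
            inputs = [{"rank": r, "payload": list(range(r + 1))} for r in range(world_size)]
        else:
            inputs = None

        # Pre-sized output list; each rank receives exactly one object.
        outputs = [None]
        scatter_object_list(outputs, inputs, src=src)
        print(f"rank {rank} received {outputs[0]}")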