diff --git a/colossalai/communication/p2p_v2.py b/colossalai/communication/p2p_v2.py index 0b575e7db..4223f78d5 100644 --- a/colossalai/communication/p2p_v2.py +++ b/colossalai/communication/p2p_v2.py @@ -1,14 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import List, Tuple, Union, Any -import pickle import io +import pickle +from typing import Any, List, Tuple, Union import torch import torch.distributed as dist -from torch.distributed import distributed_c10d as c10d from torch.distributed import ProcessGroupNCCL +from torch.distributed import distributed_c10d as c10d from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc @@ -23,7 +23,7 @@ def init_process_group(): Args: None - + Returns: None """ @@ -40,7 +40,7 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou second_rank (int): second rank in the pair Returns: - :class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks + :class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks """ if len(_pg_manager) == 0: init_process_group() @@ -51,8 +51,8 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: - """transform tensor to object with unpickle. - Info of the device in bytes stream will be modified into current device before unpickling + """transform tensor to object with unpickle. + Info of the device in bytes stream will be modified into current device before unpickling Args: tensor (:class:`torch.tensor`): tensor to be unpickled @@ -78,9 +78,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None): """This is a modified version of the broadcast_object_list in torch.distribution The only difference is that object will be move to correct device after unpickled. - If local_rank = src, then object list will be sent to rank src. Otherwise, object list will + If local_rank = src, then object list will be sent to rank src. Otherwise, object list will be updated with data sent from rank src. - + Args: object_list (List[Any]): list of object to broadcast src (int): source rank to broadcast @@ -182,7 +182,7 @@ def _recv_object(src: int) -> Any: Args: src (int): source rank of data. local rank will receive data from src rank. - + Returns: Any: Object received from src. """