Browse Source

[NFC] polish communication/p2p_v2.py code style (#2303)

pull/2317/head
ver217 2 years ago committed by Frank Lee
parent
commit
116e3d0b8f
  1. 20
      colossalai/communication/p2p_v2.py

20
colossalai/communication/p2p_v2.py

@ -1,14 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Tuple, Union, Any
import pickle
import io
import pickle
from typing import Any, List, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d as c10d
from torch.distributed import ProcessGroupNCCL
from torch.distributed import distributed_c10d as c10d
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@ -23,7 +23,7 @@ def init_process_group():
Args:
None
Returns:
None
"""
@ -40,7 +40,7 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou
second_rank (int): second rank in the pair
Returns:
:class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks
:class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks
"""
if len(_pg_manager) == 0:
init_process_group()
@ -51,8 +51,8 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
"""transform tensor to object with unpickle.
Info of the device in bytes stream will be modified into current device before unpickling
"""transform tensor to object with unpickle.
Info of the device in bytes stream will be modified into current device before unpickling
Args:
tensor (:class:`torch.tensor`): tensor to be unpickled
@ -78,9 +78,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None):
"""This is a modified version of the broadcast_object_list in torch.distribution
The only difference is that object will be move to correct device after unpickled.
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
be updated with data sent from rank src.
Args:
object_list (List[Any]): list of object to broadcast
src (int): source rank to broadcast
@ -182,7 +182,7 @@ def _recv_object(src: int) -> Any:
Args:
src (int): source rank of data. local rank will receive data from src rank.
Returns:
Any: Object received from src.
"""

Loading…
Cancel
Save