mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish communication/p2p_v2.py code style (#2303)
parent
b965585d05
commit
116e3d0b8f
|
@ -1,14 +1,14 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
from typing import List, Tuple, Union, Any
|
|
||||||
import pickle
|
|
||||||
import io
|
import io
|
||||||
|
import pickle
|
||||||
|
from typing import Any, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import distributed_c10d as c10d
|
|
||||||
from torch.distributed import ProcessGroupNCCL
|
from torch.distributed import ProcessGroupNCCL
|
||||||
|
from torch.distributed import distributed_c10d as c10d
|
||||||
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
@ -23,7 +23,7 @@ def init_process_group():
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
None
|
None
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
@ -40,7 +40,7 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou
|
||||||
second_rank (int): second rank in the pair
|
second_rank (int): second rank in the pair
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks
|
:class:`ProcessGroupNCCL`: the handle of the group consisting of the given two ranks
|
||||||
"""
|
"""
|
||||||
if len(_pg_manager) == 0:
|
if len(_pg_manager) == 0:
|
||||||
init_process_group()
|
init_process_group()
|
||||||
|
@ -51,8 +51,8 @@ def _acquire_pair_group_handle(first_rank: int, second_rank: int) -> ProcessGrou
|
||||||
|
|
||||||
|
|
||||||
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
|
def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
|
||||||
"""transform tensor to object with unpickle.
|
"""transform tensor to object with unpickle.
|
||||||
Info of the device in bytes stream will be modified into current device before unpickling
|
Info of the device in bytes stream will be modified into current device before unpickling
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor (:class:`torch.tensor`): tensor to be unpickled
|
tensor (:class:`torch.tensor`): tensor to be unpickled
|
||||||
|
@ -78,9 +78,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
|
||||||
def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None):
|
def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=None):
|
||||||
"""This is a modified version of the broadcast_object_list in torch.distribution
|
"""This is a modified version of the broadcast_object_list in torch.distribution
|
||||||
The only difference is that object will be move to correct device after unpickled.
|
The only difference is that object will be move to correct device after unpickled.
|
||||||
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
|
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
|
||||||
be updated with data sent from rank src.
|
be updated with data sent from rank src.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
object_list (List[Any]): list of object to broadcast
|
object_list (List[Any]): list of object to broadcast
|
||||||
src (int): source rank to broadcast
|
src (int): source rank to broadcast
|
||||||
|
@ -182,7 +182,7 @@ def _recv_object(src: int) -> Any:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
src (int): source rank of data. local rank will receive data from src rank.
|
src (int): source rank of data. local rank will receive data from src rank.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Any: Object received from src.
|
Any: Object received from src.
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue