mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
152 lines
3.7 KiB
152 lines
3.7 KiB
from typing import List |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from torch.distributed import ProcessGroup |
|
|
|
from .moe_info import MoeParallelInfo |
|
|
|
|
|
def is_moe_tensor(tensor: torch.Tensor) -> bool:
    """
    Tell whether ``tensor`` carries MoE parallel info.

    A tensor counts as a MoE tensor when it exposes a ``moe_info``
    attribute, which is what ``set_moe_tensor_info`` attaches.

    Args:
        tensor (torch.Tensor): The tensor to inspect.

    Returns:
        bool: True if ``tensor.moe_info`` can be read, False otherwise.
    """
    # Same semantics as ``hasattr``: only AttributeError means "absent".
    try:
        getattr(tensor, "moe_info")
    except AttributeError:
        return False
    return True
|
|
|
|
|
def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
    """
    Attach MoE parallel info to the given tensor.

    After this call ``tensor.moe_info`` holds ``moe_info``, so
    ``is_moe_tensor(tensor)`` returns True and the ``get_*`` accessors in
    this module read their values from it.

    Args:
        tensor (torch.Tensor): The tensor to tag.
        moe_info (MoeParallelInfo): The MoE parallel info to attach. Any
            previously attached info is overwritten.
    """
    # Plain attribute assignment is the idiomatic spelling of calling
    # ``tensor.__setattr__("moe_info", ...)`` directly.
    tensor.moe_info = moe_info
|
|
|
|
|
def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
    """
    Build a MoeParallelInfo from the given parallel sizes.

    No tensor is involved here. Use ``set_moe_tensor_info`` to attach the
    result to a tensor.

    Args:
        ep_size (int): The expert parallel size.
        dp_size (int): The data parallel size.
        pp_size (int): The pipeline parallel size.
        ep_inside (bool): Use ep inside dp if True, dp inside ep if False.
            This argument is required; it has no default.

    Returns:
        MoeParallelInfo: The newly constructed MoE parallel info.
    """
    # NOTE(review): arguments are passed positionally with ``ep_inside``
    # first, unlike this function's own parameter order; presumably this
    # matches MoeParallelInfo's constructor -- confirm in moe_info.py.
    return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size)
|
|
|
|
|
def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
    """
    Look up the expert parallel process group attached to ``tensor``.

    Args:
        tensor (torch.Tensor): A tensor previously tagged with
            ``set_moe_tensor_info``.

    Returns:
        torch.distributed.ProcessGroup: The value of
        ``tensor.moe_info.ep_group``.

    Raises:
        AttributeError: If ``tensor`` has no ``moe_info`` attribute.
    """
    info = tensor.moe_info
    return info.ep_group
|
|
|
|
|
def get_ep_size(tensor: torch.Tensor) -> int:
    """
    Look up the expert parallel size attached to ``tensor``.

    Args:
        tensor (torch.Tensor): A tensor previously tagged with
            ``set_moe_tensor_info``.

    Returns:
        int: The value of ``tensor.moe_info.ep_size``.

    Raises:
        AttributeError: If ``tensor`` has no ``moe_info`` attribute.
    """
    info = tensor.moe_info
    return info.ep_size
|
|
|
|
|
def get_dp_size(tensor: torch.Tensor) -> int:
    """
    Look up the data parallel size attached to ``tensor``.

    Args:
        tensor (torch.Tensor): A tensor previously tagged with
            ``set_moe_tensor_info``.

    Returns:
        int: The value of ``tensor.moe_info.dp_size``.

    Raises:
        AttributeError: If ``tensor`` has no ``moe_info`` attribute.
    """
    info = tensor.moe_info
    return info.dp_size
|
|
|
|
|
def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
    """
    Look up the data parallel process group attached to ``tensor``.

    Args:
        tensor (torch.Tensor): A tensor previously tagged with
            ``set_moe_tensor_info``.

    Returns:
        torch.distributed.ProcessGroup: The value of
        ``tensor.moe_info.dp_group``.

    Raises:
        AttributeError: If ``tensor`` has no ``moe_info`` attribute.
    """
    info = tensor.moe_info
    return info.dp_group
|
|
|
|
|
def get_ep_rank(tensor: torch.Tensor) -> int:
    """
    Return this process's rank inside the tensor's expert parallel group.

    The group is obtained with ``get_ep_group(tensor)`` and the rank is
    queried via ``torch.distributed.get_rank``. Per the torch.distributed
    docs, that call returns -1 when this process is not a member of the
    group.

    Args:
        tensor (torch.Tensor): A tensor previously tagged with
            ``set_moe_tensor_info``.

    Returns:
        int: The rank of the current process within the expert parallel
        group.
    """
    ep_group = get_ep_group(tensor)
    return dist.get_rank(group=ep_group)
|
|
|
|
|
def get_dp_rank(tensor: torch.Tensor) -> int:
    """
    Return this process's rank inside the tensor's data parallel group.

    The group is obtained with ``get_dp_group(tensor)`` and the rank is
    queried via ``torch.distributed.get_rank``. Per the torch.distributed
    docs, that call returns -1 when this process is not a member of the
    group.

    Args:
        tensor (torch.Tensor): A tensor previously tagged with
            ``set_moe_tensor_info``.

    Returns:
        int: The rank of the current process within the data parallel
        group.
    """
    dp_group = get_dp_group(tensor)
    return dist.get_rank(group=dp_group)
|
|
|
|
|
def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
    """
    Get the global ranks that make up the tensor's expert parallel group.

    Args:
        tensor (torch.Tensor): A tensor previously tagged with
            ``set_moe_tensor_info``.

    Returns:
        List[int]: The value of ``tensor.moe_info.ep_group_ranks``. The
        stored object is returned as-is (not copied), so mutating it
        mutates the tensor's moe_info.

    Raises:
        AttributeError: If ``tensor`` has no ``moe_info`` attribute.
    """
    return tensor.moe_info.ep_group_ranks
|
|
|
|
|
def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
    """
    Get the global ranks that make up the tensor's data parallel group.

    Args:
        tensor (torch.Tensor): A tensor previously tagged with
            ``set_moe_tensor_info``.

    Returns:
        List[int]: The value of ``tensor.moe_info.dp_group_ranks``. The
        stored object is returned as-is (not copied), so mutating it
        mutates the tensor's moe_info.

    Raises:
        AttributeError: If ``tensor`` has no ``moe_info`` attribute.
    """
    return tensor.moe_info.dp_group_ranks
|
|
|