from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from .moe_info import MoeParallelInfo


def is_moe_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a moe tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a moe tensor.
    """
    return hasattr(tensor, "moe_info")


def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
    """
    Set moe info for the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be set.
        moe_info (MoeParallelInfo): The moe info to be set.
    """
    setattr(tensor, "moe_info", moe_info)


def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
    """
    Get the moe parallel info for the given parallel sizes.

    Args:
        ep_size (int): The expert parallel size.
        dp_size (int): The data parallel size.
        pp_size (int): The pipeline parallel size.
        ep_inside (bool): Use ep inside dp if True, dp inside ep if False.

    Returns:
        MoeParallelInfo: The moe parallel info for the given sizes.
    """
    return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size)


def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
    """
    Get the expert parallel group of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        torch.distributed.ProcessGroup: The expert parallel group of the given tensor.
    """
    return tensor.moe_info.ep_group


def get_ep_size(tensor: torch.Tensor) -> int:
    """
    Get the expert parallel size of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        int: The expert parallel size of the given tensor.
    """
    return tensor.moe_info.ep_size


def get_dp_size(tensor: torch.Tensor) -> int:
    """
    Get the data parallel size of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        int: The data parallel size of the given tensor.
    """
    return tensor.moe_info.dp_size


def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
    """
    Get the data parallel group of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        torch.distributed.ProcessGroup: The data parallel group of the given tensor.
    """
    return tensor.moe_info.dp_group


def get_ep_rank(tensor: torch.Tensor) -> int:
    """
    Get the expert parallel rank of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        int: The expert parallel rank of the given tensor.
    """
    return dist.get_rank(get_ep_group(tensor))


def get_dp_rank(tensor: torch.Tensor) -> int:
    """
    Get the data parallel rank of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        int: The data parallel rank of the given tensor.
    """
    return dist.get_rank(get_dp_group(tensor))


def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
    """
    Get the expert parallel group ranks of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        List[int]: The expert parallel group ranks of the given tensor.
    """
    return tensor.moe_info.ep_group_ranks


def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
    """
    Get the data parallel group ranks of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        List[int]: The data parallel group ranks of the given tensor.
    """
    return tensor.moe_info.dp_group_ranks
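

# Usage sketch (illustrative, not part of the module's API): tag a
# parameter tensor with moe info and query it through the helpers above.
# Assumptions: torch.distributed is available with the gloo backend, and
# MoeParallelInfo builds its ep/dp process groups and group-rank lists
# from the default world group given (ep_inside, ep_size, dp_size,
# pp_size), as the accessor functions above expect.
if __name__ == "__main__":
    # A single-process world is enough to exercise the helpers.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

    weight = torch.zeros(4, 4)
    info = get_moe_info(ep_size=1, dp_size=1, pp_size=1, ep_inside=True)
    set_moe_tensor_info(weight, info)

    # The tensor is now a moe tensor; its parallel context is queryable.
    assert is_moe_tensor(weight)
    print("ep size:", get_ep_size(weight), "dp size:", get_dp_size(weight))
    print("ep rank:", get_ep_rank(weight), "dp rank:", get_dp_rank(weight))

    dist.destroy_process_group()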