import torch

from colossalai.utils import get_current_device
from colossalai.context.moe_context import MOE_CONTEXT

from .experts import FFNExperts, TPExperts


class ForceFP32Parameter(torch.nn.Parameter):
    """A parameter that ignores half-precision conversion and always stays in FP32."""

    def half(self, memory_format=None):
        # Override half() as a no-op so the parameter keeps FP32 precision
        # even when the surrounding model is converted to FP16.
        return self


class NormalNoiseGenerator:
    """Generates a random noisy mask for logits tensor.

    All noise is generated from a normal distribution (0, 1 / E^2), where
    E = the number of experts.

    :param num_experts: The number of experts
    :type num_experts: int
    """

    def __init__(self, num_experts: int):
        self.normal = torch.distributions.normal.Normal(
            loc=torch.tensor(0.0, device=get_current_device()),
            scale=torch.tensor(1.0 / num_experts**2, device=get_current_device())).rsample

    def __call__(self, inputs: torch.Tensor):
        # Additive noise: sample one Gaussian value per logit and add it.
        noisy = self.normal(inputs.shape)
        return inputs + noisy


class UniformNoiseGenerator:
    """Generates a random noisy mask for logits tensor, copied from mesh tensorflow:
    Multiplies values by a random number between 1 - epsilon and 1 + epsilon.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.

    :param eps: Epsilon in generator
    :type eps: float
    """

    def __init__(self, eps: float = 1e-2):
        self.uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(1.0 - eps, device=get_current_device()),
            high=torch.tensor(1.0 + eps, device=get_current_device())).rsample

    def __call__(self, inputs: torch.Tensor):
        # Multiplicative jitter: scale each logit by a factor in [1 - eps, 1 + eps].
        noisy = self.uniform(inputs.shape)
        return inputs * noisy


def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
    # Choose an expert implementation that fits the maximum expert-parallel size.
    mep_size = MOE_CONTEXT.max_ep_size
    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
        # Experts divide evenly across (or are replicated within) the parallel group.
        return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
    elif d_ff % mep_size == 0:
        # Fall back to tensor parallelism: split the FFN hidden dimension across the group.
        return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
    else:
        raise NotImplementedError(f"Cannot build {num_experts} experts on {mep_size} GPUs.")
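

# --- Usage sketch (not part of the library API) ---
# A minimal illustration of the two noise generators and the expert builder.
# Assumptions: a Colossal-AI context has already been launched (e.g. via
# colossalai.launch_from_torch) so that MOE_CONTEXT is initialized, and this
# module is executed inside its package (e.g. `python -m ...`) so the relative
# import of FFNExperts/TPExperts resolves. All shapes and hyperparameters
# below are made up for illustration.
if __name__ == "__main__":
    import torch.nn.functional as F

    # Perturb router logits for 4 experts: additive Gaussian noise or
    # multiplicative uniform jitter, both sampled on the current device.
    logits = torch.randn(8, 4, device=get_current_device())  # (tokens, experts)
    add_noise = NormalNoiseGenerator(num_experts=4)
    mul_noise = UniformNoiseGenerator(eps=1e-2)
    print(add_noise(logits).shape)  # torch.Size([8, 4])
    print(mul_noise(logits).shape)  # torch.Size([8, 4])

    # Build the expert module; whether FFNExperts or TPExperts is returned
    # depends on MOE_CONTEXT.max_ep_size, as decided in build_ffn_experts.
    experts = build_ffn_experts(num_experts=4, d_model=512, d_ff=2048, activation=F.gelu)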