2022-02-18 12:42:31 +00:00
|
|
|
import torch
|
2022-03-31 10:34:11 +00:00
|
|
|
import torch.nn.functional as F
|
2022-02-18 12:42:31 +00:00
|
|
|
from colossalai.utils import get_current_device
|
2022-03-23 10:03:39 +00:00
|
|
|
from colossalai.context.moe_context import MOE_CONTEXT
|
2022-02-27 14:28:39 +00:00
|
|
|
from .experts import FFNExperts, TPExperts
|
2022-02-18 12:42:31 +00:00
|
|
|
|
|
|
|
|
2022-03-22 02:50:20 +00:00
|
|
|
class ForceFP32Parameter(torch.nn.Parameter):
    """Parameter whose :meth:`half` does not convert to half precision.

    Calling ``half()`` on an instance hands back a copy of the underlying
    data with its dtype left untouched, instead of an fp16 conversion.
    """

    def half(self, memory_format=None):
        """Return a clone of the parameter's data, keeping its current dtype.

        Args:
            memory_format (optional): Accepted so the signature lines up with
                ``torch.Tensor.half``; it is not used here.

        Returns:
            torch.Tensor: A fresh copy of ``self.data``.
        """
        untouched_copy = self.data.clone()
        return untouched_copy
|
2022-03-22 02:50:20 +00:00
|
|
|
|
|
|
|
|
2022-02-18 12:42:31 +00:00
|
|
|
class NormalNoiseGenerator:
    """Adds zero-mean Gaussian noise to a logits tensor.

    The noise is sampled from a normal distribution with mean :math:`0` and
    standard deviation :math:`1 / E^2`, where `E = the number of experts`.

    Args:
        num_experts (int): The number of experts.
    """

    def __init__(self, num_experts: int):
        device = get_current_device()
        mean = torch.tensor(0.0, device=device)
        # ``scale`` of ``torch.distributions.Normal`` is the standard deviation.
        std = torch.tensor(1.0 / num_experts**2, device=device)
        distribution = torch.distributions.normal.Normal(loc=mean, scale=std)
        # Only the bound sampler is kept; it is called with the target shape.
        self.normal = distribution.rsample

    def __call__(self, inputs: torch.Tensor):
        """Return ``inputs`` plus a noise sample of the same shape."""
        return inputs + self.normal(inputs.shape)
|
|
|
|
|
|
|
|
|
2022-03-16 08:47:44 +00:00
|
|
|
class UniformNoiseGenerator:
    """Applies multiplicative uniform jitter to a logits tensor.

    copied from mesh tensorflow:
    Every value is multiplied by a random factor drawn between
    :math:`1-epsilon` and :math:`1+epsilon`.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.

    Args:
        eps (float, optional): Epsilon in generator, defaults 1e-2.
    """

    def __init__(self, eps: float = 1e-2):
        device = get_current_device()
        lower = torch.tensor(1.0 - eps, device=device)
        upper = torch.tensor(1.0 + eps, device=device)
        distribution = torch.distributions.uniform.Uniform(low=lower, high=upper)
        # Only the bound sampler is kept; it is called with the target shape.
        self.uniform = distribution.rsample

    def __call__(self, inputs: torch.Tensor):
        """Return ``inputs`` scaled element-wise by a sample of the same shape."""
        return inputs * self.uniform(inputs.shape)
|
|
|
|
|
|
|
|
|
2022-03-31 10:34:11 +00:00
|
|
|
def autocast_softmax(logit: torch.Tensor, dim: int):
    """Compute softmax over ``dim`` after casting ``logit`` to float32.

    Args:
        logit (torch.Tensor): Input tensor; any dtype other than float32 is
            converted with ``Tensor.float()`` first.
        dim (int): Dimension along which softmax is taken.

    Returns:
        torch.Tensor: The float32 softmax of ``logit``.
    """
    # Tensors that are already float32 are passed through unchanged.
    fp32_logit = logit if logit.dtype == torch.float32 else logit.float()
    return F.softmax(fp32_logit, dim=dim)
|
|
|
|
|
|
|
|
|
2022-02-27 14:28:39 +00:00
|
|
|
def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
    """Build an experts module chosen from ``MOE_CONTEXT.max_ep_size``.

    ``FFNExperts`` is used when ``num_experts`` and the maximum expert-parallel
    size divide one another (in either direction). Otherwise ``TPExperts`` is
    used, provided ``d_ff`` is divisible by that size.

    Args:
        num_experts (int): The number of experts.
        d_model (int): Forwarded unchanged to the experts constructor.
        d_ff (int): Forwarded unchanged to the experts constructor; also used
            for the ``TPExperts`` divisibility check.
        activation (optional): Forwarded unchanged to the experts constructor.
        drop_rate (float, optional): Forwarded unchanged to the experts
            constructor, defaults 0.

    Returns:
        An ``FFNExperts`` or ``TPExperts`` instance.

    Raises:
        NotImplementedError: If neither divisibility condition holds.
    """
    max_ep_size = MOE_CONTEXT.max_ep_size
    expert_args = (num_experts, d_model, d_ff, activation, drop_rate)

    # One of the two counts is a whole multiple of the other.
    experts_fit_evenly = num_experts % max_ep_size == 0 or max_ep_size % num_experts == 0
    if experts_fit_evenly:
        return FFNExperts(*expert_args)

    if d_ff % max_ep_size != 0:
        raise NotImplementedError(f"Can not build {num_experts} experts in {max_ep_size} GPUS.")

    return TPExperts(*expert_args)
|