mirror of https://github.com/hpcaitech/ColossalAI
[MOE] add unit tests for MOE experts layout, gradient handler and kernel (#469)
parent 1559c0df41
commit 7544347145
@ -7,12 +7,9 @@ from .initializer_pipeline import Initializer_Pipeline
from .initializer_sequence import Initializer_Sequence
from .initializer_tensor import Initializer_Tensor
from .initializer_model import Initializer_Model
from .initializer_moe import Initializer_Moe
from .process_group_initializer import ProcessGroupInitializer

__all__ = [
    'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
    'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
    'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model',
    'Initializer_Moe'
    'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D',
    'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model'
]
@ -1,119 +0,0 @@
import torch.distributed as dist

from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.global_variables import moe_env
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moemodel(ProcessGroupInitializer):
    """Model parallel initialization for MoE system.

    :param moe_model: Size of moe model parallel
    :param moe_data: Size of moe data parallel
    :param args: Args used in base class
    :param kwargs: Kwargs used in base class

    :type moe_model: int
    :type moe_data: int
    """

    def __init__(self, moe_model, moe_data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_model
        self.moe_data = moe_data

    def init_dist_group(self):
        """Initialize model parallel groups in moe parallel environment,
        and assign local_ranks and groups to each gpu.

        :return: MoE model parallelism's information
        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        group_world_size = None
        mode = ParallelMode.MOE_MODEL

        for i in range(self.moe_data):
            ranks = [i * self.moe_model + j for j in range(self.moe_model)]
            group = dist.new_group(ranks)

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, ranks_in_group, mode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moedata(ProcessGroupInitializer):
    """Data parallel initialization for MoE system.

    :param moe_model: Size of moe model parallel
    :param moe_data: Size of moe data parallel
    :param args: Args used in base class
    :param kwargs: Kwargs used in base class

    :type moe_model: int
    :type moe_data: int
    """

    def __init__(self, moe_model, moe_data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_model
        self.moe_data = moe_data

    def init_dist_group(self):
        """Initialize data parallel groups in moe parallel environment,
        and assign local_ranks and groups to each gpu.

        :return: MoE data parallelism's information
        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        group_world_size = None
        mode = ParallelMode.MOE_DATA

        for i in range(self.moe_model):
            ranks = [i + j * self.moe_model for j in range(self.moe_data)]
            group = dist.new_group(ranks)

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, ranks_in_group, mode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moe(ProcessGroupInitializer):
    """Serves as the single entry point to MoE parallel initialization.

    :param args: Args used to initialize ProcessGroupInitializer
    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_env.model_parallel_size
        self.moe_data = moe_env.data_parallel_size
        self.model_initializer = Initializer_Moemodel(
            self.moe_model, self.moe_data, *args, **kwargs)
        self.data_initializer = Initializer_Moedata(
            self.moe_model, self.moe_data, *args, **kwargs)

    def init_dist_group(self):
        """Initializes MoE parallel communication groups.

        :return: MoE parallelism's information
        :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        parallel_setting = [self.model_initializer.init_dist_group(),
                            self.data_initializer.init_dist_group()]
        return parallel_setting
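The two deleted initializers above carve the global ranks into MoE model-parallel and MoE data-parallel groups with simple arithmetic. A minimal standalone sketch of that grouping, assuming an illustrative world size of 4 with moe_model = 2 and moe_data = 2 (example values, not taken from the diff):

# Illustrative sketch of the rank grouping used by Initializer_Moemodel / Initializer_Moedata.
# Assumed example: world size 4, moe_model = 2, moe_data = 2.
moe_model, moe_data = 2, 2

# MoE model-parallel groups: consecutive ranks, one group per data-parallel replica.
model_groups = [[i * moe_model + j for j in range(moe_model)] for i in range(moe_data)]
# -> [[0, 1], [2, 3]]

# MoE data-parallel groups: strided ranks, one group per model-parallel slot.
data_groups = [[i + j * moe_model for j in range(moe_data)] for i in range(moe_model)]
# -> [[0, 2], [1, 3]]

print(model_groups, data_groups)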
@ -2,7 +2,6 @@ from typing import Optional


class TensorParallelEnv(object):

    _instance = None

    def __new__(cls, *args, **kwargs):

@ -33,7 +32,7 @@ class TensorParallelEnv(object):
        self.depth_3d = depth_3d
        self.input_group_3d = input_group_3d
        self.weight_group_3d = weight_group_3d
        self.output_group_3d = output_group_3d
        self.output_group_3d = output_group_3d

    def save(self):
        return dict(mode=self.mode,

@ -48,43 +47,4 @@ class TensorParallelEnv(object):
                    output_group_3d=self.output_group_3d)


class MoeEnv:
    """Moe environment variables.
    """

    def __init__(self):
        self.data_parallel_size = None
        self.model_parallel_size = None
        self.aux_loss = None
        self.enable_cuda = True

    def setup(self, moe_model_size):
        from .core import global_context as gpc
        if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
            raise NotImplementedError("Moe is not compatible with tensor or pipeline parallel")

        assert gpc.data_parallel_size % moe_model_size == 0, \
            "The size of data parallel needs to be divided by moe model parallel size"

        self.data_parallel_size = gpc.data_parallel_size // moe_model_size
        self.model_parallel_size = moe_model_size

    def is_initialized(self):
        return self.model_parallel_size is not None

    def set_cuda_false(self):
        self.enable_cuda = False

    def reset_loss(self):
        self.aux_loss = 0

    def add_loss(self, loss):
        self.aux_loss += loss

    def get_loss(self):
        return self.aux_loss


tensor_parallel_env = TensorParallelEnv()

moe_env = MoeEnv()
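The setup logic in the removed MoeEnv splits the global data-parallel size between MoE model parallelism and MoE data parallelism. A small sketch of that arithmetic with assumed example numbers (8 data-parallel ranks, moe_model_size = 2):

# Sketch of the MoeEnv.setup arithmetic above; the numbers are an assumed example.
global_data_parallel_size = 8    # stands in for gpc.data_parallel_size
moe_model_size = 2

assert global_data_parallel_size % moe_model_size == 0
moe_data_parallel_size = global_data_parallel_size // moe_model_size
print(moe_model_size, moe_data_parallel_size)    # -> 2 4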
@ -19,14 +19,14 @@ from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.core import global_context as gpc, MOE_CONTEXT
from colossalai.engine import Engine
from colossalai.engine.ophooks import BaseOpHook
from colossalai.global_variables import moe_env
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                              sync_model_param)
from colossalai.utils.moe import sync_moe_model_param
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2


@ -299,9 +299,11 @@ def initialize(model: nn.Module,
    if isinstance(optimizer, Callable):
        optimizer = optimizer(model.parameters())

    if not moe_env.is_initialized() and not use_zero:
    if not use_zero:
        if is_using_sequence():
            sync_model_param(model, ParallelMode.SEQUENCE_DP)
        elif MOE_CONTEXT.is_initialized:
            sync_moe_model_param(model)
        elif is_using_ddp():
            sync_model_param(model, ParallelMode.DATA)
        else:

@ -354,7 +356,7 @@ def initialize(model: nn.Module,
                "Training with zero is detected, ZeROGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
    elif is_using_ddp() and moe_env.is_initialized():
    elif is_using_ddp() and MOE_CONTEXT.is_initialized:
        gradient_handler_cfg = [dict(type='MoeGradientHandler')]
        if verbose:
            logger.info(
@ -41,21 +41,26 @@ class FusedAdam(torch.optim.Optimizer):
        set_grad_none (bool, optional): whether to set grad to None when zero_grad()
            method is called. (default: True)

    .. _Adam\: A Method for Stochastic Optimization:
    .. _Adam: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,
                 weight_decay=0., amsgrad=False, set_grad_none=True):
    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 adam_w_mode=True,
                 weight_decay=0.,
                 amsgrad=False,
                 set_grad_none=True):

        if amsgrad:
            raise RuntimeError(
                'FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay)
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
        super(FusedAdam, self).__init__(params, defaults)
        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none

@ -86,7 +91,8 @@ class FusedAdam(torch.optim.Optimizer):
        """
        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
            raise RuntimeError(
                'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')
                'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
            )
        loss = None
        if closure is not None:
            loss = closure()

@ -135,28 +141,12 @@ class FusedAdam(torch.optim.Optimizer):
                raise RuntimeError('FusedAdam only support fp16 and fp32.')

            if (len(g_16) > 0):
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     self.adam_w_mode,
                                     bias_correction,
                                     group['weight_decay'])
                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
                                     bias_correction, group['weight_decay'])
            if (len(g_32) > 0):
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     self.adam_w_mode,
                                     bias_correction,
                                     group['weight_decay'])
                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
                                     bias_correction, group['weight_decay'])

        return loss
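For context, a minimal usage sketch of the constructor reformatted above. The import path and the availability of the colossal_C CUDA extension (plus a GPU) are assumptions for illustration, not shown in the diff:

# Minimal FusedAdam usage sketch (assumes colossal_C is built, a CUDA device is present,
# and this import path; adjust as needed).
import torch
from colossalai.nn.optimizer.fused_adam import FusedAdam

model = torch.nn.Linear(16, 16).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                      adam_w_mode=True, weight_decay=0.0)

loss = model(torch.randn(4, 16, device='cuda')).sum()
loss.backward()
optimizer.step()        # called with no arguments, as the updated API requires
optimizer.zero_grad()   # sets grads to None when set_grad_none=True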
@ -46,22 +46,32 @@ class FusedLAMB(torch.optim.Optimizer):
        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                 amsgrad=False, adam_w_mode=True,
                 grad_averaging=True, set_grad_none=True,
                 max_grad_norm=1.0, use_nvlamb=False):
    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-6,
                 weight_decay=0.01,
                 amsgrad=False,
                 adam_w_mode=True,
                 grad_averaging=True,
                 set_grad_none=True,
                 max_grad_norm=1.0,
                 use_nvlamb=False):
        if amsgrad:
            raise RuntimeError(
                'FusedLAMB does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)
        super(FusedLAMB, self).__init__(params, defaults)

@ -69,8 +79,9 @@ class FusedLAMB(torch.optim.Optimizer):
            import colossal_C
            self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm
            # Skip buffer
            self._dummy_overflow_buf = torch.tensor(
                [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self._dummy_overflow_buf = torch.tensor([0],
                                                    dtype=torch.int,
                                                    device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
        else:
            raise RuntimeError('FusedLAMB requires cuda extensions')

@ -112,23 +123,16 @@ class FusedLAMB(torch.optim.Optimizer):
                raise RuntimeError('FusedLAMB only support fp16 and fp32.')

        device = self.param_groups[0]["params"][0].device
        g_norm_32, g_norm_16 = torch.zeros(
            1, device=device), torch.zeros(1, device=device)
        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
        # compute grad norm for two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0]

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]],
                                                False)[0]
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]], False)[0]
        max_grad_norm = self.defaults['max_grad_norm']

        for group in self.param_groups:

@ -176,36 +180,14 @@ class FusedLAMB(torch.optim.Optimizer):
                raise RuntimeError('FusedLAMB only support fp16 and fp32.')

            if (len(g_16) > 0):
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.adam_w_mode,
                                     global_grad_norm,
                                     max_grad_norm,
                                     self.use_nvlamb)
                multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
                                     group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction,
                                     group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm,
                                     max_grad_norm, self.use_nvlamb)
            if (len(g_32) > 0):
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.adam_w_mode,
                                     global_grad_norm,
                                     max_grad_norm,
                                     self.use_nvlamb)
                multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
                                     group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction,
                                     group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm,
                                     max_grad_norm, self.use_nvlamb)

        return loss
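The "blend two grad norms" step above takes the L2 norm of the two per-dtype norms, which equals the L2 norm over all gradients combined. A plain-PyTorch sketch of the equivalent math (illustrative only; the diff uses the fused colossal_C multi_tensor_l2norm kernel):

# Equivalent math of the grad-norm blending above, without the fused kernel.
import torch

g_all_16 = [torch.randn(8), torch.randn(8)]   # stand-ins for fp16 gradients
g_all_32 = [torch.randn(8)]                   # stand-ins for fp32 gradients

g_norm_16 = torch.norm(torch.stack([g.norm() for g in g_all_16]))
g_norm_32 = torch.norm(torch.stack([g.norm() for g in g_all_32]))

# sqrt(norm_32^2 + norm_16^2) == L2 norm over all gradients combined
global_grad_norm = torch.norm(torch.stack([g_norm_32, g_norm_16]))
print(global_grad_norm)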
@ -8,7 +8,6 @@ import torch
from torch._six import inf
from torch.nn.parameter import Parameter


try:
    import colossal_C
except:

@ -17,11 +16,9 @@ except:
from contextlib import contextmanager

import torch.distributed as dist
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS,
                                  TENSOR_PARALLEL_ATTRIBUTES)
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
from colossalai.global_variables import tensor_parallel_env as env

from .multi_tensor_apply import multi_tensor_applier

@ -116,7 +113,10 @@ def is_model_parallel_parameter(p):


def is_moe_parallel_parameter(p):
    return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1
    # FIXME(HHC): clip_grad needs to be adapted for MoE.
    # This return value must be set to False, otherwise it will raise
    # an error in training.
    return False


def _calc_l2_norm(grads):

@ -127,7 +127,7 @@ def _calc_l2_norm(grads):
            colossal_C.multi_tensor_l2norm,
            dummy_overflow_buf,
            [grads],
            False # no per-parameter norm
            False  # no per-parameter norm
        )
    return norm

@ -139,11 +139,13 @@ def _calc_lp(grads, norm_type):
        norm += grad_norm**norm_type
    return norm


def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    if torch.is_tensor(norm) and norm.device.type != 'cuda':
        norm = norm.to(torch.cuda.current_device())
    return norm


# ======== Gradient Clipping =========


@ -212,7 +214,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    else:
        tensor_parallel_grads = []
        no_tensor_parallel_grads = []
        moe_parallel_grads = [] # used to collect moe tensor parallel gradients
        moe_parallel_grads = []  # used to collect moe tensor parallel gradients
        zero_sharded_grads = []
        for p in params:
            if is_model_parallel_parameter(p):

@ -226,13 +228,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
                no_tensor_parallel_grads.append(p.grad.data)

        if norm_type == 2.0 and enable_cuda_kernels:
            tensor_parallel_norm = _calc_l2_norm(
                tensor_parallel_grads) ** norm_type
            no_tensor_parallel_norm = _calc_l2_norm(
                no_tensor_parallel_grads) ** norm_type
            moe_parallel_norm = _calc_l2_norm(
                moe_parallel_grads) ** norm_type
            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
            tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
            no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
            moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
        else:
            tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
            no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)

@ -259,10 +258,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
            no_tensor_parallel_norm += zero_sharded_norm
        total_norm = tensor_parallel_norm + no_tensor_parallel_norm
        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
            dist.all_reduce(total_norm,
                            op=dist.ReduceOp.SUM,
                            group=gpc.get_group(ParallelMode.PIPELINE))
        total_norm = total_norm ** (1.0 / norm_type)
            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
        total_norm = total_norm**(1.0 / norm_type)
        if torch.is_tensor(total_norm):
            total_norm = total_norm.item()

@ -272,10 +269,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    if enable_cuda_kernels:
        grads = [p.grad.detach() for p in params]
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(colossal_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)
        multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
    else:
        for p in params:
            p.grad.detach().mul_(clip_coeff)
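The clipping path above accumulates per-group norms raised to norm_type, sums them, takes the 1/norm_type root, and then scales gradients by a clip coefficient. A plain-PyTorch sketch of that math (illustrative; the clip_coeff formula with an epsilon term follows torch's usual convention and is an assumption, since it is not visible in this hunk):

# Sketch of the total-norm and clipping arithmetic used in clip_grad_norm_fp32.
import torch

def clip_sketch(grads, max_norm, norm_type=2.0, eps=1e-6):
    # total = (sum_i ||g_i||_p^p)^(1/p), matching the **norm_type / **(1.0 / norm_type) pattern above
    total_norm = sum(g.norm(norm_type)**norm_type for g in grads)**(1.0 / norm_type)
    clip_coeff = max_norm / (float(total_norm) + eps)   # epsilon term assumed
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)                          # same effect as the fused multi_tensor_scale path
    return float(total_norm)

grads = [torch.randn(10) for _ in range(3)]
print(clip_sketch(grads, max_norm=1.0))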
@ -7,7 +7,7 @@ from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.global_variables import moe_env
from colossalai.core import MOE_CONTEXT
from colossalai.utils import get_current_device


@ -136,7 +136,7 @@ class Widenet(nn.Module):
        self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)

    def forward(self, x):
        moe_env.reset_loss()
        MOE_CONTEXT.reset_loss()
        x = self.widenet(x)
        x = torch.mean(x, dim=1)
        x = self.linear(x)

@ -201,7 +201,7 @@ class ViTMoE(nn.Module):
        self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm)

    def forward(self, x):
        moe_env.reset_loss()
        MOE_CONTEXT.reset_loss()
        x = self.vitmoe(x)
        x = torch.mean(x, dim=1)
        x = self.linear(x)
@ -0,0 +1,72 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts
from colossalai.core import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.testing import assert_equal_in_group

BATCH_SIZE = 4
DIM = 16
CONFIG = dict()


def run_test(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    expert_module = nn.Linear
    expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device())

    MOE_CONTEXT.setup(42)    # MOE initialization
    noisy_func = UniformNoiseGenerator()
    router = Top1Router(noisy_func=noisy_func)
    num_experts_list = [1, 2, 4]
    layer_list = []
    for num_experts in num_experts_list:
        exp = Experts(expert_module, num_experts, **expert_factor)
        moe_layer = MoeLayer(DIM, num_experts, router, exp)
        layer_list.append(moe_layer)

    model = nn.Sequential(*layer_list)
    model = model.to(get_current_device())
    sync_moe_model_param(model)

    dist_dict = MOE_CONTEXT.information
    assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group)
    assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group)
    # MoE model synchronization passed

    grad_handler = MoeGradientHandler(model, 0)

    rank = dist.get_rank()
    torch.cuda.manual_seed(78 + rank)
    data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
    grad = torch.randn_like(data)

    MOE_CONTEXT.reset_loss()
    outputs = model(data)
    outputs.backward(grad)
    grad_handler.handle_gradient()

    assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
    assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group)

    assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group)
    assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group)
    # MoE grad handler test passed


@pytest.mark.dist
def test_grad_handler():
    world_size = 4
    run_func = partial(run_test, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_grad_handler()
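The assertions above check that, after handle_gradient(), each expert's gradients are identical across its data-parallel group. A simplified, hypothetical sketch of what such a handler does for one parameter (this is not the MoeGradientHandler implementation, only the all-reduce averaging the test implies):

# Hypothetical sketch of gradient synchronization over a MoE data-parallel group.
import torch.distributed as dist

def sync_expert_grad(param, dp_group):
    # average the expert's gradient across the ranks that hold a replica of it
    if param.grad is not None and dp_group is not None:
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM, group=dp_group)
        param.grad.data.div_(dist.get_world_size(dp_group))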
@ -7,57 +7,64 @@ import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts
from colossalai.context.random import moe_set_seed
from colossalai.global_variables import moe_env
from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
from colossalai.core import MOE_CONTEXT

BATCH_SIZE = 32
BATCH_SIZE = 16
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
CONFIG = dict()


def check_equal(A, B, atol=1e-06):
    assert torch.allclose(A, B, rtol=0, atol=atol) is True
def check_equal(tensor_a, tensor_b, atol=1e-06):
    assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True


def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    moe_set_seed(42)

def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router):
    # We do not need TF32 here, since it introduces absolute error in the results
    torch.backends.cuda.matmul.allow_tf32 = False

    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
    torch.manual_seed(rs + local_rank)
    moe_env.reset_loss()

    MOE_CONTEXT.setup(42)    # MOE environment initialization
    MOE_CONTEXT.reset_loss()
    torch.manual_seed(rs + local_rank)    # give each process a different random seed

    # get randomized data
    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)

    router = Top2Router(1)
    expert = Experts(nn.Identity, 4)
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
    expert_module = nn.Linear
    expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
    expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
    if data_type == torch.float16:
        layer = layer.half()
    layer.cuda_mode = False

    # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine
    layer.use_kernel = False
    old_out = layer(tokens)

    ech = old_out.shape
    grad = torch.randn(ech, device=get_current_device())
    old_out.backward(grad)
    old_out.backward(grad)    # get gradient

    # save all results
    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.gate.weight.grad.data.clone()

    # reset all gradients
    tokens.grad.zero_()
    layer.gate.weight.grad.zero_()

    layer.cuda_mode = True
    new_out = layer(tokens)
    layer.use_kernel = True
    new_out = layer(tokens)    # get outputs through the colossal kernel

    if data_type == torch.float32:
        check_equal(old_out, new_out)
    else:
        check_equal(old_out, new_out, 1e-2)
    # forward function passed

    new_out.backward(grad)
    new_out.backward(grad)    # get new type gradient
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.gate.weight.grad.data.clone()

@ -65,28 +72,31 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
        check_equal(o_tk_grad, n_tk_grad)
    else:
        check_equal(o_tk_grad, n_tk_grad, 1e-2)
    # tokens gradient is correct

    if data_type == torch.float32:
        check_equal(o_gt_grad, n_gt_grad, 5e-05)
    else:
        check_equal(o_gt_grad, n_gt_grad, 2e-01)
    # gate weight gradient is correct


@pytest.mark.skip(reason="MoE refactoring has not finished yet")
@pytest.mark.dist
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
@pytest.mark.parametrize("router", [Top1Router, Top2Router])
def test_moe_kernel(rs, hidden_size, data_type, router):
    world_size = 4
    run_func = partial(run_routing,
                       world_size=world_size,
                       port=free_port(),
                       rs=rs,
                       hidden_size=hidden_size,
                       data_type=data_type)
                       data_type=data_type,
                       router=router)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_top2(2, 256, torch.float16)
    test_moe_kernel(2, 256, torch.float16, Top2Router)
@ -0,0 +1,70 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Experts
from colossalai.core import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.testing import assert_equal_in_group

D_MODEL = 4
D_FF = 8
CONFIG = dict()


def run_test(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    expert_module = nn.Linear
    expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())

    MOE_CONTEXT.setup(42)    # MOE environment initialization
    exp0 = Experts(expert_module, 1, **expert_factor)
    exp1 = Experts(expert_module, 2, **expert_factor)
    exp2 = Experts(expert_module, 4, **expert_factor)
    exp3 = Experts(expert_module, 8, **expert_factor)

    assert exp0.num_local_experts == 1
    assert exp1.num_local_experts == 1
    assert exp2.num_local_experts == 1
    assert exp3.num_local_experts == 2
    # experts deployment passed

    dist_dict = MOE_CONTEXT.information
    rank = dist.get_rank()

    assert len(dist_dict) == 3
    assert dist.get_rank(dist_dict[4].ep_group) == rank
    assert dist.get_rank(dist_dict[2].ep_group) == rank % 2
    assert dist.get_rank(dist_dict[1].ep_group) == 0

    assert dist.get_rank(dist_dict[4].dp_group) == 0
    assert dist.get_rank(dist_dict[2].dp_group) == rank // 2
    assert dist.get_rank(dist_dict[1].dp_group) == rank
    # group creation passed

    model = nn.ModuleList([exp0, exp1, exp2, exp3])
    model = model.to(get_current_device())
    sync_moe_model_param(model)

    assert_equal_in_group(exp0.experts[0].weight.data, dist_dict[1].dp_group)
    assert_equal_in_group(exp0.experts[0].bias.data, dist_dict[1].dp_group)
    # MOE experts layout success when ep_size = 1

    assert_equal_in_group(exp1.experts[0].weight.data, dist_dict[2].dp_group)
    assert_equal_in_group(exp1.experts[0].bias.data, dist_dict[2].dp_group)
    # MOE experts layout success when ep_size = 2


@pytest.mark.dist
def test_moe_initialization():
    world_size = 4
    run_func = partial(run_test, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_initialization()
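The expert placement checked by the assertions above can be summarized by a simple rule. The helper below is inferred from the test for world_size = 4 and is only an illustration, not library code:

# Inferred layout rule (sketch): for E experts on W ranks,
#   expert-parallel size   ep = min(E, W)
#   data-parallel size     dp = W // ep
#   local experts per rank = max(1, E // W)
def expected_layout(num_experts, world_size=4):
    ep_size = min(num_experts, world_size)
    dp_size = world_size // ep_size
    num_local_experts = max(1, num_experts // world_size)
    return ep_size, dp_size, num_local_experts

for e in (1, 2, 4, 8):
    print(e, expected_layout(e))
# -> 1 (1, 4, 1), 2 (2, 2, 1), 4 (4, 1, 1), 8 (4, 1, 2)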
@ -1,97 +0,0 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, MoeLayer
from colossalai.global_variables import moe_env

BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))


def check_equal(A, B, atol=1e-06):
    assert torch.allclose(A, B, rtol=0, atol=atol) is True


def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # torch.set_printoptions(precision=30)
    torch.backends.cuda.matmul.allow_tf32 = False
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
    torch.manual_seed(rs + local_rank)
    moe_env.reset_loss()
    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
    # print(f"tokens:\n{tokens}")
    router = Top1Router(1)
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
    if data_type == torch.float16:
        layer = layer.half()
    layer.cuda_mode = False

    old_out = layer(tokens)
    # print(f"old output:\n{old_out}")

    ech = old_out.shape
    grad = torch.randn(ech, device=get_current_device())
    old_out.backward(grad)

    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.gate.weight.grad.data.clone()

    tokens.grad.zero_()
    layer.gate.weight.grad.zero_()

    layer.cuda_mode = True
    new_out = layer(tokens)

    # print(torch.max(torch.abs(old_out - new_out)))
    if data_type == torch.float32:
        check_equal(old_out, new_out)
    else:
        check_equal(old_out, new_out, 1e-2)
    # print(f"forward functions passed")

    # print(f"new output:\n{new_out}")
    new_out.backward(grad)
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.gate.weight.grad.data.clone()

    # print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
    if data_type == torch.float32:
        check_equal(o_tk_grad, n_tk_grad)
    else:
        check_equal(o_tk_grad, o_tk_grad, 1e-2)
    # print(f"tokens gradient passed")

    # print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
    if data_type == torch.float32:
        check_equal(o_gt_grad, n_gt_grad, 5e-05)
    else:
        check_equal(o_gt_grad, n_gt_grad, 2e-01)
    # print(f"linear weight gradient passed")


@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
    world_size = 4
    run_func = partial(run_routing,
                       world_size=world_size,
                       port=free_port(),
                       rs=rs,
                       hidden_size=hidden_size,
                       data_type=data_type)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_top2(60, 512, torch.float16)
@ -1,97 +0,0 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env

BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))


def check_equal(A, B, atol=1e-06):
    assert torch.allclose(A, B, rtol=0, atol=atol) is True


def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # torch.set_printoptions(precision=30)
    torch.backends.cuda.matmul.allow_tf32 = False
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
    torch.manual_seed(rs + local_rank)
    moe_env.reset_loss()
    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
    # print(f"tokens:\n{tokens}")
    router = Top2Router(1)
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
    if data_type == torch.float16:
        layer = layer.half()
    layer.cuda_mode = False

    old_out = layer(tokens)
    # print(f"old output:\n{old_out}")

    ech = old_out.shape
    grad = torch.randn(ech, device=get_current_device())
    old_out.backward(grad)

    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.gate.weight.grad.data.clone()

    tokens.grad.zero_()
    layer.gate.weight.grad.zero_()

    layer.cuda_mode = True
    new_out = layer(tokens)

    # print(torch.max(torch.abs(old_out - new_out)))
    if data_type == torch.float32:
        check_equal(old_out, new_out)
    else:
        check_equal(old_out, new_out, 1e-2)
    # print(f"forward functions passed")

    # print(f"new output:\n{new_out}")
    new_out.backward(grad)
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.gate.weight.grad.data.clone()

    # print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
    if data_type == torch.float32:
        check_equal(o_tk_grad, n_tk_grad)
    else:
        check_equal(o_tk_grad, o_tk_grad, 1e-2)
    # print(f"tokens gradient passed")

    # print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
    if data_type == torch.float32:
        check_equal(o_gt_grad, n_gt_grad, 5e-05)
    else:
        check_equal(o_gt_grad, n_gt_grad, 2e-01)
    # print(f"linear weight gradient passed")


@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
    world_size = 4
    run_func = partial(run_routing,
                       world_size=world_size,
                       port=free_port(),
                       rs=rs,
                       hidden_size=hidden_size,
                       data_type=data_type)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_top2(2, 256, torch.float16)