[MOE] add unitest for MOE experts layout, gradient handler and kernel (#469)

pull/474/head
HELSON 2022-03-21 13:35:04 +08:00 committed by GitHub
parent 1559c0df41
commit 7544347145
13 changed files with 263 additions and 499 deletions

View File

@@ -7,12 +7,9 @@ from .initializer_pipeline import Initializer_Pipeline
 from .initializer_sequence import Initializer_Sequence
 from .initializer_tensor import Initializer_Tensor
 from .initializer_model import Initializer_Model
-from .initializer_moe import Initializer_Moe
 from .process_group_initializer import ProcessGroupInitializer

 __all__ = [
-    'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
-    'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
-    'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model',
-    'Initializer_Moe'
+    'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D',
+    'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model'
 ]

View File

@@ -1,119 +0,0 @@
-import torch.distributed as dist
-
-from colossalai.registry import DIST_GROUP_INITIALIZER
-from colossalai.global_variables import moe_env
-from .process_group_initializer import ProcessGroupInitializer
-from ..parallel_mode import ParallelMode
-
-
-@DIST_GROUP_INITIALIZER.register_module
-class Initializer_Moemodel(ProcessGroupInitializer):
-    """Model parallel initialization for MoE system.
-
-    :param moe_model: Size of moe model parallel
-    :param moe_data: Size of moe data parallel
-    :param args: Args used in base class
-    :param kwargs: Kwargs used in base class
-
-    :type moe_model: int
-    :type moe_data: int
-    """
-
-    def __init__(self, moe_model, moe_data, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.moe_model = moe_model
-        self.moe_data = moe_data
-
-    def init_dist_group(self):
-        """Initialize model parallel groups in moe parallel environment,
-        and assign local_ranks and groups to each gpu.
-
-        :return: MoE model parallelism's information
-        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
-        """
-        local_rank = None
-        ranks_in_group = None
-        process_group = None
-        group_world_size = None
-        mode = ParallelMode.MOE_MODEL
-
-        for i in range(self.moe_data):
-            ranks = [i * self.moe_model + j for j in range(self.moe_model)]
-            group = dist.new_group(ranks)
-
-            if self.rank in ranks:
-                local_rank = ranks.index(self.rank)
-                group_world_size = len(ranks)
-                process_group = group
-                ranks_in_group = ranks
-
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
-
-
-@DIST_GROUP_INITIALIZER.register_module
-class Initializer_Moedata(ProcessGroupInitializer):
-    """Data parallel initialization for MoE system.
-
-    :param moe_model: Size of moe model parallel
-    :param moe_data: Size of moe data parallel
-    :param args: Args used in base class
-    :param kwargs: Kwargs used in base class
-
-    :type moe_model: int
-    :type moe_data: int
-    """
-
-    def __init__(self, moe_model, moe_data, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.moe_model = moe_model
-        self.moe_data = moe_data
-
-    def init_dist_group(self):
-        """Initialize data parallel groups in moe parallel environment,
-        and assign local_ranks and groups to each gpu.
-
-        :return: MoE data parallelism's information
-        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
-        """
-        local_rank = None
-        ranks_in_group = None
-        process_group = None
-        group_world_size = None
-        mode = ParallelMode.MOE_DATA
-
-        for i in range(self.moe_model):
-            ranks = [i + j * self.moe_model for j in range(self.moe_data)]
-            group = dist.new_group(ranks)
-
-            if self.rank in ranks:
-                local_rank = ranks.index(self.rank)
-                group_world_size = len(ranks)
-                process_group = group
-                ranks_in_group = ranks
-
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
-
-
-@DIST_GROUP_INITIALIZER.register_module
-class Initializer_Moe(ProcessGroupInitializer):
-    """Serves as the single entry point to MoE parallel initialization.
-
-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.moe_model = moe_env.model_parallel_size
-        self.moe_data = moe_env.data_parallel_size
-        self.model_initializer = Initializer_Moemodel(
-            self.moe_model, self.moe_data, *args, **kwargs)
-        self.data_initializer = Initializer_Moedata(
-            self.moe_model, self.moe_data, *args, **kwargs)
-
-    def init_dist_group(self):
-        """Initializes MoE parallel communication groups.
-
-        :return: MoE parallelism's information
-        :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
-        """
-        parallel_setting = [self.model_initializer.init_dist_group(),
-                            self.data_initializer.init_dist_group()]
-        return parallel_setting
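For reference, the rank layout produced by the two removed initializers can be reproduced standalone. The sketch below (helper name and printout are illustrative only, assuming a flat rank numbering of world_size = moe_model * moe_data GPUs) shows how the two list comprehensions partition ranks into model-parallel and data-parallel groups.

# Minimal sketch of the rank layout used by the removed Initializer_Moemodel /
# Initializer_Moedata classes above.
def moe_rank_layout(moe_model: int, moe_data: int):
    # model-parallel groups: consecutive ranks share one group
    model_groups = [[i * moe_model + j for j in range(moe_model)] for i in range(moe_data)]
    # data-parallel groups: ranks with the same offset inside their model group
    data_groups = [[i + j * moe_model for j in range(moe_data)] for i in range(moe_model)]
    return model_groups, data_groups

# Example: 4 GPUs with moe_model=2, moe_data=2
# model groups -> [[0, 1], [2, 3]], data groups -> [[0, 2], [1, 3]]
print(moe_rank_layout(2, 2))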

View File

@@ -2,7 +2,6 @@ from typing import Optional

 class TensorParallelEnv(object):
     _instance = None

     def __new__(cls, *args, **kwargs):
@@ -48,43 +47,4 @@ class TensorParallelEnv(object):
                     output_group_3d=self.output_group_3d)


-class MoeEnv:
-    """MoE environment variables.
-    """
-
-    def __init__(self):
-        self.data_parallel_size = None
-        self.model_parallel_size = None
-        self.aux_loss = None
-        self.enable_cuda = True
-
-    def setup(self, moe_model_size):
-        from .core import global_context as gpc
-        if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
-            raise NotImplementedError("Moe is not compatible with tensor or pipeline parallel")
-
-        assert gpc.data_parallel_size % moe_model_size == 0, \
-            "The size of data parallel needs to be divided by moe model parallel size"
-
-        self.data_parallel_size = gpc.data_parallel_size // moe_model_size
-        self.model_parallel_size = moe_model_size
-
-    def is_initialized(self):
-        return self.model_parallel_size is not None
-
-    def set_cuda_false(self):
-        self.enable_cuda = False
-
-    def reset_loss(self):
-        self.aux_loss = 0
-
-    def add_loss(self, loss):
-        self.aux_loss += loss
-
-    def get_loss(self):
-        return self.aux_loss
-
-
 tensor_parallel_env = TensorParallelEnv()
-moe_env = MoeEnv()
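The removed MoeEnv.setup() derives the MoE data-parallel size by dividing the global data-parallel size by the requested MoE model-parallel size. A quick standalone check of that arithmetic (the sizes below are made up for illustration):

# Worked example of the removed MoeEnv.setup() arithmetic (values are illustrative).
global_data_parallel_size = 8    # stands in for gpc.data_parallel_size
moe_model_size = 2               # requested MoE model-parallel size

assert global_data_parallel_size % moe_model_size == 0, \
    "data parallel size must be divisible by the MoE model parallel size"
moe_data_parallel_size = global_data_parallel_size // moe_model_size
print(moe_data_parallel_size)    # -> 4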

View File

@@ -19,14 +19,14 @@ from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.engine import Engine
 from colossalai.engine.ophooks import BaseOpHook
-from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
+from colossalai.utils.moe import sync_moe_model_param
 from colossalai.zero import convert_to_zero_v2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
@@ -299,9 +299,11 @@ def initialize(model: nn.Module,
     if isinstance(optimizer, Callable):
         optimizer = optimizer(model.parameters())

-    if not moe_env.is_initialized() and not use_zero:
+    if not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
+        elif MOE_CONTEXT.is_initialized:
+            sync_moe_model_param(model)
         elif is_using_ddp():
             sync_model_param(model, ParallelMode.DATA)
         else:
@@ -354,7 +356,7 @@ def initialize(model: nn.Module,
                 "Training with zero is detected, ZeROGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
-        elif is_using_ddp() and moe_env.is_initialized():
+        elif is_using_ddp() and MOE_CONTEXT.is_initialized:
             gradient_handler_cfg = [dict(type='MoeGradientHandler')]
             if verbose:
                 logger.info(

View File

@@ -41,21 +41,26 @@ class FusedAdam(torch.optim.Optimizer):
        set_grad_none (bool, optional): whether set grad to None when zero_grad()
            method is called. (default: True)

-    .. _Adam\: A Method for Stochastic Optimization:
+    .. _Adam: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

-    def __init__(self, params, lr=1e-3, bias_correction=True,
-                 betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,
-                 weight_decay=0., amsgrad=False, set_grad_none=True):
+    def __init__(self,
+                 params,
+                 lr=1e-3,
+                 bias_correction=True,
+                 betas=(0.9, 0.999),
+                 eps=1e-8,
+                 adam_w_mode=True,
+                 weight_decay=0.,
+                 amsgrad=False,
+                 set_grad_none=True):

        if amsgrad:
-            raise RuntimeError(
-                'FusedAdam does not support the AMSGrad variant.')
-        defaults = dict(lr=lr, bias_correction=bias_correction,
-                        betas=betas, eps=eps, weight_decay=weight_decay)
+            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
+        defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
        super(FusedAdam, self).__init__(params, defaults)
        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none
@@ -86,7 +91,8 @@ class FusedAdam(torch.optim.Optimizer):
        """
        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
            raise RuntimeError(
-                'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.')
+                'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
+            )
        loss = None
        if closure is not None:
            loss = closure()
@@ -135,28 +141,12 @@ class FusedAdam(torch.optim.Optimizer):
                raise RuntimeError('FusedAdam only support fp16 and fp32.')

            if (len(g_16) > 0):
-                multi_tensor_applier(self.multi_tensor_adam,
-                                     self._dummy_overflow_buf,
-                                     [g_16, p_16, m_16, v_16],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     group['step'],
-                                     self.adam_w_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
+                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
+                                     bias_correction, group['weight_decay'])
            if (len(g_32) > 0):
-                multi_tensor_applier(self.multi_tensor_adam,
-                                     self._dummy_overflow_buf,
-                                     [g_32, p_32, m_32, v_32],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     group['step'],
-                                     self.adam_w_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
+                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
+                                     bias_correction, group['weight_decay'])

        return loss
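A minimal usage sketch based on the constructor shown above. It assumes the colossal_C CUDA extension is built and that FusedAdam is importable from colossalai.nn.optimizer (import path assumed); the model and data are placeholders.

import torch
import torch.nn as nn
from colossalai.nn.optimizer import FusedAdam    # import path assumed

model = nn.Linear(16, 16).cuda()
# same defaults as torch.optim.Adam, plus the fused AdamW switch
optimizer = FusedAdam(model.parameters(), lr=1e-3, adam_w_mode=True, weight_decay=0.0)

loss = model(torch.randn(4, 16, device='cuda')).sum()
loss.backward()
optimizer.step()        # step() takes no arguments in the updated API
optimizer.zero_grad()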

View File

@@ -46,22 +46,32 @@ class FusedLAMB(torch.optim.Optimizer):
        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

-    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
+    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

-    def __init__(self, params, lr=1e-3, bias_correction=True,
-                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
-                 amsgrad=False, adam_w_mode=True,
-                 grad_averaging=True, set_grad_none=True,
-                 max_grad_norm=1.0, use_nvlamb=False):
+    def __init__(self,
+                 params,
+                 lr=1e-3,
+                 bias_correction=True,
+                 betas=(0.9, 0.999),
+                 eps=1e-6,
+                 weight_decay=0.01,
+                 amsgrad=False,
+                 adam_w_mode=True,
+                 grad_averaging=True,
+                 set_grad_none=True,
+                 max_grad_norm=1.0,
+                 use_nvlamb=False):
        if amsgrad:
-            raise RuntimeError(
-                'FusedLAMB does not support the AMSGrad variant.')
-        defaults = dict(lr=lr, bias_correction=bias_correction,
-                        betas=betas, eps=eps, weight_decay=weight_decay,
-                        grad_averaging=grad_averaging,
-                        max_grad_norm=max_grad_norm)
+            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
+        defaults = dict(lr=lr,
+                        bias_correction=bias_correction,
+                        betas=betas,
+                        eps=eps,
+                        weight_decay=weight_decay,
+                        grad_averaging=grad_averaging,
+                        max_grad_norm=max_grad_norm)
        super(FusedLAMB, self).__init__(params, defaults)
@@ -69,8 +79,9 @@ class FusedLAMB(torch.optim.Optimizer):
            import colossal_C
            self.multi_tensor_l2norm = colossal_C.multi_tensor_l2norm
            # Skip buffer
-            self._dummy_overflow_buf = torch.tensor(
-                [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
+            self._dummy_overflow_buf = torch.tensor([0],
+                                                    dtype=torch.int,
+                                                    device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
        else:
            raise RuntimeError('FusedLAMB requires cuda extensions')
@@ -112,23 +123,16 @@ class FusedLAMB(torch.optim.Optimizer):
                raise RuntimeError('FusedLAMB only support fp16 and fp32.')

        device = self.param_groups[0]["params"][0].device
-        g_norm_32, g_norm_16 = torch.zeros(
-            1, device=device), torch.zeros(1, device=device)
+        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)

        # compute grad norm for two lists
        if len(g_all_32) > 0:
-            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
-                                             self._dummy_overflow_buf,
-                                             [g_all_32], False)[0]
+            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False)[0]
        if len(g_all_16) > 0:
-            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
-                                             self._dummy_overflow_buf,
-                                             [g_all_16], False)[0]
+            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0]

        # blend two grad norms to get global grad norm
-        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                self._dummy_overflow_buf,
-                                                [[g_norm_32, g_norm_16]],
-                                                False)[0]
+        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf,
+                                                [[g_norm_32, g_norm_16]], False)[0]
        max_grad_norm = self.defaults['max_grad_norm']

        for group in self.param_groups:
@@ -176,36 +180,14 @@ class FusedLAMB(torch.optim.Optimizer):
                raise RuntimeError('FusedLAMB only support fp16 and fp32.')

            if (len(g_16) > 0):
-                multi_tensor_applier(self.multi_tensor_lamb,
-                                     self._dummy_overflow_buf,
-                                     [g_16, p_16, m_16, v_16],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     group['step'],
-                                     bias_correction,
-                                     group['weight_decay'],
-                                     grad_averaging,
-                                     self.adam_w_mode,
-                                     global_grad_norm,
-                                     max_grad_norm,
-                                     self.use_nvlamb)
+                multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
+                                     group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction,
+                                     group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm,
+                                     max_grad_norm, self.use_nvlamb)
            if (len(g_32) > 0):
-                multi_tensor_applier(self.multi_tensor_lamb,
-                                     self._dummy_overflow_buf,
-                                     [g_32, p_32, m_32, v_32],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     group['step'],
-                                     bias_correction,
-                                     group['weight_decay'],
-                                     grad_averaging,
-                                     self.adam_w_mode,
-                                     global_grad_norm,
-                                     max_grad_norm,
-                                     self.use_nvlamb)
+                multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
+                                     group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction,
+                                     group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm,
+                                     max_grad_norm, self.use_nvlamb)

        return loss

View File

@@ -8,7 +8,6 @@ import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter

 try:
     import colossal_C
 except:
@@ -17,11 +16,9 @@ except:
 from contextlib import contextmanager

 import torch.distributed as dist
-from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS,
-                                  TENSOR_PARALLEL_ATTRIBUTES)
+from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.global_variables import moe_env
 from colossalai.global_variables import tensor_parallel_env as env

 from .multi_tensor_apply import multi_tensor_applier
@@ -116,7 +113,10 @@ def is_model_parallel_parameter(p):

 def is_moe_parallel_parameter(p):
-    return hasattr(p, 'moe_param') and moe_env.data_parallel_size > 1
+    # FIXME(HHC): clip_grad needs to be changed to adapt to MoE.
+    # This must return False for now, otherwise it will raise
+    # an error in training.
+    return False


 def _calc_l2_norm(grads):
@@ -127,7 +127,7 @@ def _calc_l2_norm(grads):
            colossal_C.multi_tensor_l2norm,
            dummy_overflow_buf,
            [grads],
-            False # no per-parameter norm
+            False    # no per-parameter norm
        )
    return norm
@@ -139,11 +139,13 @@ def _calc_lp(grads, norm_type):
        norm += grad_norm**norm_type
    return norm


 def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    if torch.is_tensor(norm) and norm.device.type != 'cuda':
        norm = norm.to(torch.cuda.current_device())
    return norm


 # ======== Gradient Clipping =========
@@ -212,7 +214,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    else:
        tensor_parallel_grads = []
        no_tensor_parallel_grads = []
-        moe_parallel_grads = [] # used to collect moe tensor parallel gradients
+        moe_parallel_grads = []    # used to collect moe tensor parallel gradients
        zero_sharded_grads = []
        for p in params:
            if is_model_parallel_parameter(p):
@@ -226,13 +228,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
                no_tensor_parallel_grads.append(p.grad.data)

        if norm_type == 2.0 and enable_cuda_kernels:
-            tensor_parallel_norm = _calc_l2_norm(
-                tensor_parallel_grads) ** norm_type
-            no_tensor_parallel_norm = _calc_l2_norm(
-                no_tensor_parallel_grads) ** norm_type
-            moe_parallel_norm = _calc_l2_norm(
-                moe_parallel_grads) ** norm_type
-            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
+            tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
+            no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
+            moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
+            zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
        else:
            tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
            no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
@@ -259,10 +258,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
            no_tensor_parallel_norm += zero_sharded_norm
        total_norm = tensor_parallel_norm + no_tensor_parallel_norm
        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            dist.all_reduce(total_norm,
-                            op=dist.ReduceOp.SUM,
-                            group=gpc.get_group(ParallelMode.PIPELINE))
-        total_norm = total_norm ** (1.0 / norm_type)
+            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
+        total_norm = total_norm**(1.0 / norm_type)
        if torch.is_tensor(total_norm):
            total_norm = total_norm.item()
@@ -272,10 +269,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    if enable_cuda_kernels:
        grads = [p.grad.detach() for p in params]
        dummy_overflow_buf = torch.cuda.IntTensor([0])
-        multi_tensor_applier(colossal_C.multi_tensor_scale,
-                             dummy_overflow_buf,
-                             [grads, grads],
-                             clip_coeff)
+        multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
    else:
        for p in params:
            p.grad.detach().mul_(clip_coeff)
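clip_grad_norm_fp32 above blends per-bucket norms by summing each bucket's norm raised to the p-th power and taking a single p-th root at the end. The standalone sketch below reproduces that reduction in plain PyTorch (the function name is illustrative, not part of the library):

import torch

def combine_group_norms(group_grads, norm_type=2.0):
    # per-group Lp norms raised to the p-th power, summed, then rooted once,
    # mirroring how clip_grad_norm_fp32 blends its gradient buckets
    total = 0.0
    for grads in group_grads:
        if len(grads) == 0:
            continue
        group_norm = torch.norm(torch.stack([g.norm(norm_type) for g in grads]), norm_type)
        total += group_norm ** norm_type
    return total ** (1.0 / norm_type)

bucket_a = [torch.randn(8), torch.randn(8)]
bucket_b = [torch.randn(8)]
print(combine_group_norms([bucket_a, bucket_b]))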

View File

@@ -7,7 +7,7 @@ from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
 from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
@@ -136,7 +136,7 @@ class Widenet(nn.Module):
        self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)

    def forward(self, x):
-        moe_env.reset_loss()
+        MOE_CONTEXT.reset_loss()
        x = self.widenet(x)
        x = torch.mean(x, dim=1)
        x = self.linear(x)
@@ -201,7 +201,7 @@ class ViTMoE(nn.Module):
        self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm)

    def forward(self, x):
-        moe_env.reset_loss()
+        MOE_CONTEXT.reset_loss()
        x = self.vitmoe(x)
        x = torch.mean(x, dim=1)
        x = self.linear(x)

View File

@@ -0,0 +1,72 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.multiprocessing as mp
+import torch.distributed as dist
+
+import colossalai
+from colossalai.utils import free_port, get_current_device
+from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts
+from colossalai.core import MOE_CONTEXT
+from colossalai.utils.moe import sync_moe_model_param
+from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.testing import assert_equal_in_group
+
+BATCH_SIZE = 4
+DIM = 16
+CONFIG = dict()
+
+
+def run_test(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    expert_module = nn.Linear
+    expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device())
+
+    MOE_CONTEXT.setup(42)    # MOE initialization
+    noisy_func = UniformNoiseGenerator()
+    router = Top1Router(noisy_func=noisy_func)
+    num_experts_list = [1, 2, 4]
+    layer_list = []
+    for num_experts in num_experts_list:
+        exp = Experts(expert_module, num_experts, **expert_factor)
+        moe_layer = MoeLayer(DIM, num_experts, router, exp)
+        layer_list.append(moe_layer)
+
+    model = nn.Sequential(*layer_list)
+    model = model.to(get_current_device())
+    sync_moe_model_param(model)
+
+    dist_dict = MOE_CONTEXT.information
+    assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group)
+    assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group)
+    # MoE model synchronization passed
+
+    grad_handler = MoeGradientHandler(model, 0)
+
+    rank = dist.get_rank()
+    torch.cuda.manual_seed(78 + rank)
+    data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
+    grad = torch.randn_like(data)
+
+    MOE_CONTEXT.reset_loss()
+    outputs = model(data)
+    outputs.backward(grad)
+    grad_handler.handle_gradient()
+
+    assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
+    assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group)
+    assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group)
+    assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group)
+    # MoE grad handler test passed
+
+
+@pytest.mark.dist
+def test_grad_handler():
+    world_size = 4
+    run_func = partial(run_test, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_grad_handler()
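The gradient assertions above only hold if the handler synchronizes each expert's gradients across that expert's data-parallel group after the backward pass. Conceptually it must perform something like the following; this is a hypothetical sketch of that averaging step, not the actual MoeGradientHandler implementation.

import torch.distributed as dist

def allreduce_moe_grads(experts, dp_group):
    # average expert gradients over the expert's data-parallel group so that
    # replicas of the same expert stay in sync, which is what the test checks
    world = dist.get_world_size(dp_group)
    for p in experts.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, group=dp_group)
            p.grad.data.div_(world)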

View File

@@ -7,57 +7,64 @@ import colossalai
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
-from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts
-from colossalai.context.random import moe_set_seed
-from colossalai.global_variables import moe_env
+from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
+from colossalai.core import MOE_CONTEXT

-BATCH_SIZE = 32
+BATCH_SIZE = 16
 NUM_EXPERTS = 4
-CONFIG = dict(parallel=dict(moe=dict(size=4)))
+CONFIG = dict()


-def check_equal(A, B, atol=1e-06):
-    assert torch.allclose(A, B, rtol=0, atol=atol) is True
+def check_equal(tensor_a, tensor_b, atol=1e-06):
+    assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True


-def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    moe_set_seed(42)
+def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router):
+    # Here we do not need TF32, since it brings absolute error on results
    torch.backends.cuda.matmul.allow_tf32 = False
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    torch.manual_seed(rs + local_rank)
-    moe_env.reset_loss()
+
+    MOE_CONTEXT.setup(42)    # MOE environment initialization
+    MOE_CONTEXT.reset_loss()
+    torch.manual_seed(rs + local_rank)    # give each process a different random seed
+
+    # get randomized data
    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)

-    router = Top2Router(1)
-    expert = Experts(nn.Identity, 4)
-    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
+    expert_module = nn.Linear
+    expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
+    expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
+    layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
    if data_type == torch.float16:
        layer = layer.half()
-    layer.cuda_mode = False

+    # use matrix multiplication instead of COL_MOE_KERNEL in MoE dispatch and combine
+    layer.use_kernel = False
    old_out = layer(tokens)
    ech = old_out.shape
    grad = torch.randn(ech, device=get_current_device())
-    old_out.backward(grad)
+    old_out.backward(grad)    # get gradient

+    # save all results
    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.gate.weight.grad.data.clone()

+    # reset all gradients
    tokens.grad.zero_()
    layer.gate.weight.grad.zero_()

-    layer.cuda_mode = True
-    new_out = layer(tokens)
+    layer.use_kernel = True
+    new_out = layer(tokens)    # get outputs through colossal kernel

    if data_type == torch.float32:
        check_equal(old_out, new_out)
    else:
        check_equal(old_out, new_out, 1e-2)
+    # forward function passed

-    new_out.backward(grad)
+    new_out.backward(grad)    # get new type gradient
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.gate.weight.grad.data.clone()
@@ -65,28 +72,31 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
        check_equal(o_tk_grad, n_tk_grad)
    else:
        check_equal(o_tk_grad, o_tk_grad, 1e-2)
+    # tokens gradient is correct

    if data_type == torch.float32:
        check_equal(o_gt_grad, n_gt_grad, 5e-05)
    else:
        check_equal(o_gt_grad, n_gt_grad, 2e-01)
+    # bias gradient is correct


+@pytest.mark.skip(reason="MoE refactoring has not finished yet")
 @pytest.mark.dist
 @pytest.mark.parametrize("rs", [131])
 @pytest.mark.parametrize("hidden_size", [32, 144])
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
-def test_moe_top2(rs, hidden_size, data_type):
+@pytest.mark.parametrize("router", [Top1Router, Top2Router])
+def test_moe_kernel(rs, hidden_size, data_type, router):
    world_size = 4
    run_func = partial(run_routing,
                       world_size=world_size,
                       port=free_port(),
                       rs=rs,
                       hidden_size=hidden_size,
-                       data_type=data_type)
+                       data_type=data_type,
+                       router=router)
    mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_moe_top2(2, 256, torch.float16)
+    test_moe_kernel(2, 256, torch.float16, Top2Router)

View File

@@ -0,0 +1,70 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.multiprocessing as mp
+import torch.distributed as dist
+
+import colossalai
+from colossalai.utils import free_port, get_current_device
+from colossalai.nn.layer.moe import Experts
+from colossalai.core import MOE_CONTEXT
+from colossalai.utils.moe import sync_moe_model_param
+from colossalai.testing import assert_equal_in_group
+
+D_MODEL = 4
+D_FF = 8
+CONFIG = dict()
+
+
+def run_test(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    expert_module = nn.Linear
+    expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())
+
+    MOE_CONTEXT.setup(42)    # MOE environment initialization
+    exp0 = Experts(expert_module, 1, **expert_factor)
+    exp1 = Experts(expert_module, 2, **expert_factor)
+    exp2 = Experts(expert_module, 4, **expert_factor)
+    exp3 = Experts(expert_module, 8, **expert_factor)
+
+    assert exp0.num_local_experts == 1
+    assert exp1.num_local_experts == 1
+    assert exp2.num_local_experts == 1
+    assert exp3.num_local_experts == 2
+    # experts deployment passed
+
+    dist_dict = MOE_CONTEXT.information
+    rank = dist.get_rank()
+
+    assert len(dist_dict) == 3
+    assert dist.get_rank(dist_dict[4].ep_group) == rank
+    assert dist.get_rank(dist_dict[2].ep_group) == rank % 2
+    assert dist.get_rank(dist_dict[1].ep_group) == 0
+
+    assert dist.get_rank(dist_dict[4].dp_group) == 0
+    assert dist.get_rank(dist_dict[2].dp_group) == rank // 2
+    assert dist.get_rank(dist_dict[1].dp_group) == rank
+    # group creation passed
+
+    model = nn.ModuleList([exp0, exp1, exp2, exp3])
+    model = model.to(get_current_device())
+    sync_moe_model_param(model)
+
+    assert_equal_in_group(exp0.experts[0].weight.data, dist_dict[1].dp_group)
+    assert_equal_in_group(exp0.experts[0].bias.data, dist_dict[1].dp_group)
+    # MOE experts layout success when ep_size = 1
+
+    assert_equal_in_group(exp1.experts[0].weight.data, dist_dict[2].dp_group)
+    assert_equal_in_group(exp1.experts[0].bias.data, dist_dict[2].dp_group)
+    # MOE experts layout success when ep_size = 2
+
+
+@pytest.mark.dist
+def test_moe_initialization():
+    world_size = 4
+    run_func = partial(run_test, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_initialization()
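The asserts in this test are consistent with a layout rule where each expert set uses an expert-parallel size of min(num_experts, world_size) and keeps num_experts // ep_size experts per rank. The sketch below spells out that inferred rule; it is an assumption read off the asserts, not the library's code.

def inferred_layout(num_experts: int, world_size: int):
    # assumption inferred from the test: experts spread over at most
    # `world_size` ranks, each rank holding num_experts // ep_size of them
    ep_size = min(num_experts, world_size)
    num_local_experts = max(num_experts // ep_size, 1)
    dp_size = world_size // ep_size
    return ep_size, num_local_experts, dp_size

for n in (1, 2, 4, 8):
    print(n, inferred_layout(n, world_size=4))
# -> ep_size 1/2/4/4, local experts 1/1/1/2, matching the asserts above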

View File

@@ -1,97 +0,0 @@
-from functools import partial
-
-import pytest
-import torch
-import torch.nn as nn
-import torch.multiprocessing as mp
-
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import free_port, get_current_device
-from colossalai.nn.layer.moe import Top1Router, MoeLayer
-from colossalai.global_variables import moe_env
-
-BATCH_SIZE = 32
-NUM_EXPERTS = 4
-CONFIG = dict(parallel=dict(moe=dict(size=4)))
-
-
-def check_equal(A, B, atol=1e-06):
-    assert torch.allclose(A, B, rtol=0, atol=atol) is True
-
-
-def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    # torch.set_printoptions(precision=30)
-    torch.backends.cuda.matmul.allow_tf32 = False
-    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    torch.manual_seed(rs + local_rank)
-    moe_env.reset_loss()
-    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
-    # print(f"tokens:\n{tokens}")
-
-    router = Top1Router(1)
-    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
-    if data_type == torch.float16:
-        layer = layer.half()
-    layer.cuda_mode = False
-
-    old_out = layer(tokens)
-    # print(f"old output:\n{old_out}")
-
-    ech = old_out.shape
-    grad = torch.randn(ech, device=get_current_device())
-    old_out.backward(grad)
-
-    o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
-
-    tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
-
-    layer.cuda_mode = True
-    new_out = layer(tokens)
-
-    # print(torch.max(torch.abs(old_out - new_out)))
-    if data_type == torch.float32:
-        check_equal(old_out, new_out)
-    else:
-        check_equal(old_out, new_out, 1e-2)
-    # print(f"forward functions passed")
-
-    # print(f"new output:\n{new_out}")
-
-    new_out.backward(grad)
-    n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
-
-    # print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
-    if data_type == torch.float32:
-        check_equal(o_tk_grad, n_tk_grad)
-    else:
-        check_equal(o_tk_grad, o_tk_grad, 1e-2)
-    # print(f"tokens gradient passed")
-
-    # print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
-    if data_type == torch.float32:
-        check_equal(o_gt_grad, n_gt_grad, 5e-05)
-    else:
-        check_equal(o_gt_grad, n_gt_grad, 2e-01)
-    # print(f"linear weight gradient passed")
-
-
-@pytest.mark.skip(reason="Should be activated for detailed tests")
-@pytest.mark.parametrize("rs", [2, 42, 60])
-@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
-@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
-def test_moe_top2(rs, hidden_size, data_type):
-    world_size = 4
-    run_func = partial(run_routing,
-                       world_size=world_size,
-                       port=free_port(),
-                       rs=rs,
-                       hidden_size=hidden_size,
-                       data_type=data_type)
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_moe_top2(60, 512, torch.float16)

View File

@@ -1,97 +0,0 @@
-from functools import partial
-
-import pytest
-import torch
-import torch.nn as nn
-import torch.multiprocessing as mp
-
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import free_port, get_current_device
-from colossalai.nn.layer.moe import Top2Router, MoeLayer
-from colossalai.global_variables import moe_env
-
-BATCH_SIZE = 32
-NUM_EXPERTS = 4
-CONFIG = dict(parallel=dict(moe=dict(size=4)))
-
-
-def check_equal(A, B, atol=1e-06):
-    assert torch.allclose(A, B, rtol=0, atol=atol) is True
-
-
-def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    # torch.set_printoptions(precision=30)
-    torch.backends.cuda.matmul.allow_tf32 = False
-    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    torch.manual_seed(rs + local_rank)
-    moe_env.reset_loss()
-    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
-    # print(f"tokens:\n{tokens}")
-
-    router = Top2Router(1)
-    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
-    if data_type == torch.float16:
-        layer = layer.half()
-    layer.cuda_mode = False
-
-    old_out = layer(tokens)
-    # print(f"old output:\n{old_out}")
-
-    ech = old_out.shape
-    grad = torch.randn(ech, device=get_current_device())
-    old_out.backward(grad)
-
-    o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
-
-    tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
-
-    layer.cuda_mode = True
-    new_out = layer(tokens)
-
-    # print(torch.max(torch.abs(old_out - new_out)))
-    if data_type == torch.float32:
-        check_equal(old_out, new_out)
-    else:
-        check_equal(old_out, new_out, 1e-2)
-    # print(f"forward functions passed")
-
-    # print(f"new output:\n{new_out}")
-
-    new_out.backward(grad)
-    n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
-
-    # print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
-    if data_type == torch.float32:
-        check_equal(o_tk_grad, n_tk_grad)
-    else:
-        check_equal(o_tk_grad, o_tk_grad, 1e-2)
-    # print(f"tokens gradient passed")
-
-    # print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
-    if data_type == torch.float32:
-        check_equal(o_gt_grad, n_gt_grad, 5e-05)
-    else:
-        check_equal(o_gt_grad, n_gt_grad, 2e-01)
-    # print(f"linear weight gradient passed")
-
-
-@pytest.mark.skip(reason="Should be activated for detailed tests")
-@pytest.mark.parametrize("rs", [2, 42, 60])
-@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
-@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
-def test_moe_top2(rs, hidden_size, data_type):
-    world_size = 4
-    run_func = partial(run_routing,
-                       world_size=world_size,
-                       port=free_port(),
-                       rs=rs,
-                       hidden_size=hidden_size,
-                       data_type=data_type)
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_moe_top2(2, 256, torch.float16)