diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py
index f65be3869..4cc411a78 100644
--- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py
+++ b/colossalai/engine/gradient_handler/_moe_gradient_handler.py
@@ -16,6 +16,9 @@ class MoeGradientHandler(BaseGradientHandler):
     the same type to improve the efficiency of communication.
     """

+    def __init__(self, model, optimizer=None):
+        super().__init__(model, optimizer)
+
     def handle_gradient(self):
         """A method running an all-reduce operation in a data parallel group.
         Then running an all-reduce operation for all parameters in experts
@@ -24,13 +27,15 @@ class MoeGradientHandler(BaseGradientHandler):
         global_data = gpc.data_parallel_size

         if global_data > 1:
-            param_dict = get_moe_epsize_param_dict(self._model)
+            epsize_param_dict = get_moe_epsize_param_dict(self._model)

+            # epsize is 1, indicating the params are replicated among processes in data parallelism
+            # use the ParallelMode.DATA to get data parallel group
             # reduce gradients for all parameters in data parallelism
-            if 1 in param_dict:
-                bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
+            if 1 in epsize_param_dict:
+                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))

-            for ep_size in param_dict:
+            for ep_size in epsize_param_dict:
                 if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
-                    bucket_allreduce(param_list=param_dict[ep_size],
+                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
                                      group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
diff --git a/colossalai/engine/ophooks/__init__.py b/colossalai/engine/ophooks/__init__.py
index 8a1071e38..412df33c3 100644
--- a/colossalai/engine/ophooks/__init__.py
+++ b/colossalai/engine/ophooks/__init__.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Callable, Optional

 import torch

@@ -61,18 +61,25 @@ class PostBackwardFunction(torch.autograd.Function):
         return (None, None) + args


-def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
+def register_ophooks_recursively(module: torch.nn.Module,
+                                 ophook_list: List[BaseOpHook] = None,
+                                 name: str = "",
+                                 filter_fn: Optional[Callable] = None):
     r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)

     # Add hooks for submodules
     for child_name, child in module.named_children():
-        register_ophooks_recursively(child, ophook_list, name + child_name)
+        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)

     # Early return on modules with no parameters.
     if len(list(module.parameters(recurse=False))) == 0:
         return

+    # return from filtered module
+    if filter_fn is not None and filter_fn(module):
+        return
+
     if ophook_list is not None:
         for hook in ophook_list:
             assert (isinstance(hook, BaseOpHook))
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
index a23b09b12..367278d4d 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/nn/layer/moe/experts.py
@@ -35,7 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

-    @no_shard_zero_decrator
+    @no_shard_zero_decrator(is_replicated=False)
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index d518ba3f2..04b4f1a58 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -9,7 +9,7 @@ from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
-from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator, autocast_softmax
 from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
@@ -66,7 +66,7 @@ class Top1Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

-        logits = F.softmax(inputs, dim=-1)
+        logits = autocast_softmax(inputs, dim=-1)
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)

@@ -152,7 +152,7 @@ class Top2Router(nn.Module):
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

-        logits = F.softmax(inputs, dim=-1)    # logits: [s, e]
+        logits = autocast_softmax(inputs, dim=-1)    # logits: [s, e]
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)

@@ -241,7 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
""" - @no_shard_zero_decrator + @no_shard_zero_decrator(is_replicated=True) def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts): super().__init__() self.d_model = dim_model diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index 06913dc9a..a13c8184e 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F from colossalai.utils import get_current_device from colossalai.context.moe_context import MOE_CONTEXT from .experts import FFNExperts, TPExperts @@ -51,6 +52,12 @@ class UniformNoiseGenerator: return inputs * noisy +def autocast_softmax(logit: torch.Tensor, dim: int): + if logit.dtype != torch.float32: + logit = logit.float() + return F.softmax(logit, dim=dim) + + def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): mep_size = MOE_CONTEXT.max_ep_size if num_experts % mep_size == 0 or mep_size % num_experts == 0: diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py index e38587367..31888f7f1 100644 --- a/colossalai/utils/memory_tracer/model_data_memtracer.py +++ b/colossalai/utils/memory_tracer/model_data_memtracer.py @@ -39,7 +39,7 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]: if t.device.type == 'cpu': _cpu_mem_usage += t.numel() * t.element_size() elif t.device.type == 'cuda': - _cuda_mem_usages += t.numel() * t.element_size() + _cuda_mem_usage += t.numel() * t.element_size() return _cuda_mem_usage, _cpu_mem_usage cuda_mem_usage = 0 diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index db6431f7d..4be143c15 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -88,6 +88,8 @@ class ZeroContextConfig(object): """The configuration used to control zero context initialization. Args: + replicated (bool, optional): Whether the param is replicated across data parallel group. + Some parameters are not replicated, e.g. parameters in MOE experts. shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished. This will reduce memory usage when initializing model. @@ -97,8 +99,9 @@ class ZeroContextConfig(object): See torchvision resnet18. Defaults to False. 
""" - def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False): + def __init__(self, replicated: bool = True, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False): super().__init__() + self.is_replicated: bool = replicated self.shard_param: bool = shard_param self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly @@ -139,10 +142,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): self.model_numel_tensor = model_numel_tensor self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) - self.config = ZeroContextConfig(shard_param=shard_param, + self.config = ZeroContextConfig(replicated=True, + shard_param=shard_param, rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly) ZeroContextMgr().current_context = self + @property + def is_replicated(self): + return self.config.is_replicated + @property def shard_param(self): return self.config.shard_param @@ -183,6 +191,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): self.model_numel_tensor += param.numel() + # mark whether the param is replicated + param.is_replicated = self.is_replicated + # convert parameters to half param_half = half_fn(param) param.data = param_half @@ -224,14 +235,20 @@ class ZeroContextMgr(metaclass=SingletonMeta): self.current_context.config = old_config -def no_shard_zero_context(): - return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False) +def no_shard_zero_context(is_replicated: bool = True): + return ZeroContextMgr().hijack_context_config(replicated=is_replicated, + shard_param=False, + rm_torch_payload_on_the_fly=False) + + +def no_shard_zero_decrator(is_replicated: bool = True): + def _wrapper(init_func): -def no_shard_zero_decrator(init_func): + def _no_shard(*args, **kwargs): + with no_shard_zero_context(is_replicated): + init_func(*args, **kwargs) - def _no_shard(*args, **kwargs): - with no_shard_zero_context(): - init_func(*args, **kwargs) + return _no_shard - return _no_shard + return _wrapper diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 1e60be6da..ce87c1a86 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc from colossalai.engine.ophooks import register_ophooks_recursively from colossalai.engine.ophooks.zero_hook import ZeroHook from colossalai.engine.paramhooks import BaseParamHookMgr +from colossalai.engine.gradient_handler.utils import bucket_allreduce from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector @@ -67,17 +68,27 @@ class ShardedModelV2(nn.Module): self.logger = get_dist_logger() # We force users to use ZeroInitContext - sharded = [] - unsharded = [] + for submodule in module.modules(): + sharded_cnt = 0 + unshard_cnt = 0 + for param in submodule.parameters(recurse=False): + assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.' 
+                if param.colo_attr.param_is_sharded:
+                    sharded_cnt += 1
+                else:
+                    unshard_cnt += 1
+            assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param'
+            submodule.param_is_sharded = (sharded_cnt > 0)
+
+        self.sharded_params = []
+        self.unshard_params = []
         for param in module.parameters():
-            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.colo_attr.param_is_sharded)
-            unsharded.append(not param.colo_attr.param_is_sharded)
-        assert all(sharded) or all(
-            unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
-        self.shard_param = all(sharded)
-        self.module = module
+            if param.colo_attr.param_is_sharded:
+                self.sharded_params.append(param)
+            else:
+                self.unshard_params.append(param)
+        self.module = module
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
         self.world_size = dist.get_world_size(self.process_group)
@@ -95,8 +106,8 @@ class ShardedModelV2(nn.Module):

         # Register hooks
         self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
-        register_ophooks_recursively(self.module, self._ophook_list)
-        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
+        register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
+        self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

         self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -185,7 +196,6 @@ class ShardedModelV2(nn.Module):

     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
-        self._post_backward_operations()
         for ophook in self._ophook_list:
             ophook.post_iter()

@@ -224,17 +234,21 @@ class ShardedModelV2(nn.Module):
             # Wait for the non-blocking GPU -> CPU grad transfers to finish.
             torch.cuda.current_stream().synchronize()
         self.reducer.free()
+
+        # all reduce gradients for unsharded parameters
+        reduce_list = [p for p in self.unshard_params if p.is_replicated]
+        bucket_allreduce(reduce_list, self.process_group)
+
         # 3. shard tensors not dealed in the zero hook
-        if self.shard_param:
-            tensor_list = []
-            for p in self.module.parameters():
-                if not p.colo_attr.param_is_sharded:
-                    tensor_list.append(p.colo_attr.sharded_data_tensor)
-                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.colo_attr.remove_torch_payload()
-            self.shard_strategy.shard(tensor_list, self.process_group)
-
-        # 4. move sharded param grad payload to param.grad
+        tensor_list = []
+        for p in self.sharded_params:
+            if not p.colo_attr.param_is_sharded:
+                tensor_list.append(p.colo_attr.sharded_data_tensor)
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                p.colo_attr.remove_torch_payload()
+        self.shard_strategy.shard(tensor_list, self.process_group)
+
+        # 4. set all parameters' grad to None
         for p in self.module.parameters():
             if not p.requires_grad:
                 continue
@@ -245,6 +259,16 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
+
+            # move unsharded param grad to saved_grad
+            if not p.colo_attr.param_is_sharded:
+                if p.colo_attr.offload_grad:
+                    colo_model_data_move_to_cpu(p.grad)
+                if p.colo_attr.saved_grad.is_null():
+                    p.colo_attr.saved_grad.reset_payload(p.grad.data)
+                else:
+                    p.colo_attr.saved_grad.payload.add_(p.grad.data)
+
             p.grad = None

     @torch.no_grad()
@@ -320,16 +344,14 @@ class ShardedModelV2(nn.Module):
                 param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
-                                   self.process_group)
+        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
         prev_params = {}
-        for p in self.module.parameters():
+        for p in self.sharded_params:
             prev_params[p] = p.data
             p.data = p.colo_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
-                                  self.process_group)
-        for p in self.module.parameters():
+        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
+        for p in self.sharded_params:
             p.data = prev_params[p]
         return gathered_state_dict
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
index d8f46ddea..0fa5067a5 100644
--- a/tests/test_moe/test_moe_zero_init.py
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -22,19 +22,14 @@ class MoeModel(nn.Module):

     def __init__(self):
         super().__init__()
-        self.proj1 = nn.Linear(4, 8)
+        self.proj1 = nn.Linear(4, 16)
         expert_cls = nn.Linear
-        expert_args_dict = dict(in_features=8, out_features=8)
-        self.moe = MoeModule(dim_model=8,
-                             num_experts=8,
-                             noisy_policy='Jitter',
-                             use_residual=True,
-                             expert_cls=expert_cls,
-                             **expert_args_dict)
-        self.proj2 = nn.Linear(8, 4)
+        expert_args_dict = dict(in_features=16, out_features=16)
+        self.moe = MoeModule(dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict)
+        self.proj2 = nn.Linear(16, 4)

     def forward(self, x):
-        x = self.proj(x)
+        x = self.proj1(x)
         x = self.moe(x)
         x = self.proj2(x)
         return x
@@ -75,6 +70,12 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         else:
             assert param.colo_attr.sharded_data_tensor.is_sharded

+        # the parameters in moe experts are not replicated
+        if 'experts' in name:
+            assert not param.is_replicated
+        else:
+            assert param.is_replicated
+
         assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
             f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
new file mode 100644
index 000000000..5b5a73d05
--- /dev/null
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -0,0 +1,78 @@
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.utils import free_port
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.context import MOE_CONTEXT
+from colossalai.testing import assert_equal_in_group
+
+from tests.test_zero_data_parallel.common import CONFIG, check_grads_padding, run_fwd_bwd
+from tests.test_moe.test_moe_zero_init import MoeModel
+
+
+@parameterize("enable_autocast", [False])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_model_test(enable_autocast, shard_strategy_class):
+    shard_strategy = shard_strategy_class()
+
+    get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
+    _, train_dataloader, _, _, criterion = get_components_func()
+
+    rm_torch_payload_on_the_fly = False
+
+    with ZeroInitContext(target_device=torch.cuda.current_device(),
+                         shard_strategy=shard_strategy,
+                         shard_param=True,
+                         rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
+        zero_model = MoeModel()
+    zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
+
+    # check whether parameters are identical in ddp
+    for name, p in zero_model.named_parameters():
+        if not p.colo_attr.param_is_sharded and p.is_replicated:
+            assert_equal_in_group(p.data)
+
+    model = MoeModel().half()
+    col_model_deepcopy(zero_model, model)
+    model = model.cuda()
+    grad_handler = MoeGradientHandler(model)
+
+    for i, (data, label) in enumerate(train_dataloader):
+        if i > 5:
+            break
+
+        data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
+        run_fwd_bwd(model, data, label, criterion, enable_autocast)
+        run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
+        grad_handler.handle_gradient()
+
+        check_grads_padding(model, zero_model, loose=True)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    MOE_CONTEXT.setup(seed=42)
+    MOE_CONTEXT.reset_loss()
+    run_model_test()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_moe_zero_model(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_zero_model(world_size=2)
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 2abc8c53d..c7948c0fe 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -91,15 +91,19 @@ def check_params(model, zero_model, loose=False):


 def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
-        zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
-        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
-        if rank >= len(chunks):
-            continue
-        grad = chunks[rank].float()
-        if zero_grad.size(0) > grad.size(0):
-            zero_grad = zero_grad[:grad.size(0)]
+        if zero_p.colo_attr.param_is_sharded:
+            zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
+            chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
+            if rank >= len(chunks):
+                continue
+            grad = chunks[rank].float()
+            if zero_grad.size(0) > grad.size(0):
+                zero_grad = zero_grad[:grad.size(0)]
+        else:
+            grad = p.grad
+            zero_grad = zero_p.colo_attr.saved_grad.payload
         assert grad.dtype == zero_grad.dtype
         assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'
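
Note on the `no_shard_zero_decrator` change above: the helper is converted from a plain decorator into a decorator factory, which is why call sites now carry parentheses (`@no_shard_zero_decrator(is_replicated=False)` instead of `@no_shard_zero_decrator`). The following is a minimal standalone sketch of that pattern only; `fake_no_shard_context`, `no_shard_decorator` and `FakeExperts` are illustrative stand-ins, not ColossalAI APIs.

```python
from contextlib import contextmanager


@contextmanager
def fake_no_shard_context(is_replicated: bool = True):
    # Stand-in for no_shard_zero_context(); it only records the flag it received.
    print(f"enter no-shard context, is_replicated={is_replicated}")
    yield
    print("exit no-shard context")


def no_shard_decorator(is_replicated: bool = True):
    # The outer call captures the flag; this is why the decorator now needs
    # to be invoked with parentheses at the call site.
    def _wrapper(init_func):

        def _no_shard(*args, **kwargs):
            # Run the wrapped __init__ inside the (fake) no-shard context,
            # forwarding the flag captured by the outer call.
            with fake_no_shard_context(is_replicated):
                init_func(*args, **kwargs)

        return _no_shard

    return _wrapper


class FakeExperts:

    @no_shard_decorator(is_replicated=False)    # parentheses are required now
    def __init__(self, num_experts: int):
        self.num_experts = num_experts


FakeExperts(num_experts=4)    # prints enter/exit messages with is_replicated=False
```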