From 8c90d4df545957fca06d0ef8201ba9f0a40d06b7 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Tue, 29 Mar 2022 17:57:59 +0800
Subject: [PATCH] [zero] add zero context manager to change config during
 initialization (#546)

---
 colossalai/nn/layer/moe/experts.py       |  2 +
 colossalai/nn/layer/moe/layers.py        | 19 +++--
 colossalai/zero/init_ctx/__init__.py     |  4 +-
 colossalai/zero/init_ctx/init_context.py | 81 +++++++++++++++++---
 tests/test_moe/test_moe_zero_init.py     | 97 ++++++++++++++++++++++++
 5 files changed, 185 insertions(+), 18 deletions(-)
 create mode 100644 tests/test_moe/test_moe_zero_init.py

diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
index 652a8321a..a23b09b12 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/nn/layer/moe/experts.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.zero.init_ctx import no_shard_zero_decrator
 from typing import Type
 
 
@@ -34,6 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """
 
+    @no_shard_zero_decrator
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
 
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index d94cf986c..d518ba3f2 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -1,3 +1,4 @@
+import functools
 import math
 
 import torch
@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
 from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
 
@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
 
         return cb_weight, sec_mask
 
 
-class FP32LinearGate(nn.Linear):
+class FP32LinearGate(nn.Module):
     """Gate module used in MOE layer. Just a linear function without bias.
     But it should be kept as fp32 forever.
 
@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
         weight (ForceFP32Parameter): The weight of linear gate
     """
 
-    def __init__(self, d_model: int, num_experts: int):
-        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
-        self.weight = ForceFP32Parameter(self.weight)
+    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
+        super().__init__()
+        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
+        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
+
+    def forward(self, x: torch.Tensor):
+        return F.linear(x, self.weight)
 
 
 class MoeLayer(nn.Module):
@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
""" + @no_shard_zero_decrator def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts): super().__init__() self.d_model = dim_model @@ -361,7 +368,6 @@ class MoeModule(nn.Module): min_capacity=min_capacity, noisy_func=noisy_func, drop_tks=drop_tks) - self.use_residual = use_residual if use_residual: if residual_instance is not None: @@ -371,7 +377,8 @@ class MoeModule(nn.Module): "Expert class can't be None when residual instance is not given" self.residual_module = expert_cls(**expert_args) - self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) + with no_shard_zero_context(): + self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) if expert_instance is not None: self.experts = expert_instance diff --git a/colossalai/zero/init_ctx/__init__.py b/colossalai/zero/init_ctx/__init__.py index 804b36b02..0a6f81566 100644 --- a/colossalai/zero/init_ctx/__init__.py +++ b/colossalai/zero/init_ctx/__init__.py @@ -1,3 +1,3 @@ -from .init_context import ZeroInitContext +from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator -__all__ = ['ZeroInitContext'] +__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator'] diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index 9d5812d1a..e093ea8db 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -1,9 +1,11 @@ +import contextlib import functools from typing import Optional import torch from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.context.singleton_meta import SingletonMeta from colossalai.logging import get_dist_logger from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 @@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object): pass +class ZeroContextConfig(object): + """The configuration used to control zero context initialization. + + Args: + shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. + rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished. + This will reduce memory usage when initializing model. + But it's not suitable for all models, especially when there are `weight init` operations in `__init__`. + If set to `False`, remove tensor payload on param.data afther the context exist. + This is used when you add some logic to operate tensors in __init__ of module. + See torchvision resnet18. Defaults to False. + """ + + def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False): + super().__init__() + self.shard_param: bool = shard_param + self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly + + class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): """A context to initialize model. @@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): 3. Shard the param and grad according to flags. Args: - convert_fp16 (bool): Whether to convert params to fp16. target_device (torch.device): The device where param data after exiting the context. shard_strategy (BaseShardStrategy): Shard strategy instance. shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. - shard_grad (bool, optional): Is param sharded after exiting the context. 
Defaults to False. rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished. This will reduce memory usage when initializing model. But it's not suitable for all models, especially when there are `weight init` operations in `__init__`. @@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): super().__init__() self.target_device = target_device - self.shard_param = shard_param self.shard_strategy = shard_strategy - self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly self.initialized_param_list = [] self.model_numel_tensor = model_numel_tensor self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) + self.config = ZeroContextConfig(shard_param=shard_param, + rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly) + ZeroContextMgr().current_context = self + + @property + def shard_param(self): + return self.config.shard_param + + @property + def rm_torch_payload_on_the_fly(self): + return self.config.rm_torch_payload_on_the_fly + def _pre_context_exec(self): """ The Callback function when entering the context @@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): The function to call at the end of the constructor of each module. NOTE() The module may be passed to this function multiple times. """ + + def half_fn(t: torch.Tensor): + return t.half() if t.is_floating_point() else t + for param in module.parameters(recurse=False): # avoid adapting a param to ShardedParam twice if hasattr(param, 'col_attr'): @@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): self.model_numel_tensor += param.numel() - target_device = self.target_device - - # convert to fp16 - param.data = param.data.to(torch.half) + # convert parameters to half + param_half = half_fn(param) + param.data = param_half if param.grad is not None: - param.grad = param.grad.to(torch.half) + grad_half = half_fn(param.grad) + param.grad.data = grad_half # move torch parameters to the target device + target_device = self.target_device param.data = param.data.to(target_device) if param.grad is not None: param.grad = param.grad.to(target_device) param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly) - self.initialized_param_list.append(param) if self.shard_param: self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group) + self.initialized_param_list.append(param) # We must cast buffers # If we use BN, buffers may be on CPU and Float @@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): for buffer in module.buffers(recurse=False): buffer.data = buffer.data.to(device=torch.cuda.current_device()) buffer.data = cast_tensor_to_fp16(buffer.data) + + +class ZeroContextMgr(metaclass=SingletonMeta): + current_context: Optional[ZeroInitContext] = None + + @contextlib.contextmanager + def hijack_context_config(self, **kwargs): + if self.current_context is None: + yield + else: + old_config = self.current_context.config + self.current_context.config = ZeroContextConfig(**kwargs) + yield + self.current_context.config = old_config + + +def no_shard_zero_context(): + return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False) + + +def no_shard_zero_decrator(init_func): + + def _no_shard(*args, **kwargs): + with no_shard_zero_context(): + init_func(*args, **kwargs) + + return _no_shard diff --git 
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
new file mode 100644
index 000000000..0c4e66ffa
--- /dev/null
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -0,0 +1,97 @@
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from colossalai.logging import get_dist_logger
+from colossalai.testing import parameterize
+from colossalai.utils import free_port
+from colossalai.context import MOE_CONTEXT
+from colossalai.nn.layer import MoeModule
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+
+from colossalai.testing import rerun_on_exception
+from colossalai.utils import get_current_device
+from tests.test_zero_data_parallel.common import CONFIG
+
+
+class MoeModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.proj1 = nn.Linear(4, 8)
+        expert_cls = nn.Linear
+        expert_args_dict = dict(in_features=8, out_features=8)
+        self.moe = MoeModule(dim_model=8,
+                             num_experts=8,
+                             noisy_policy='Jitter',
+                             use_residual=True,
+                             expert_cls=expert_cls,
+                             **expert_args_dict)
+        self.proj2 = nn.Linear(8, 4)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        x = self.moe(x)
+        x = self.proj2(x)
+        return x
+
+
+@parameterize("init_device_type", ['cpu', 'cuda'])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_moe_zero_init(init_device_type, shard_strategy_class):
+    logger = get_dist_logger("test_moe_zero_init")
+
+    if init_device_type == 'cuda':
+        init_device = torch.device(f"cuda:{get_current_device()}")
+    elif init_device_type == 'cpu':
+        init_device = torch.device("cpu")
+    else:
+        raise NotImplementedError("Unknown device found.")
+
+    model_numel_tensor = torch.zeros(1, dtype=torch.int)
+    with ZeroInitContext(target_device=init_device,
+                         shard_strategy=shard_strategy_class(),
+                         shard_param=True,
+                         model_numel_tensor=model_numel_tensor,
+                         rm_torch_payload_on_the_fly=False):
+        model = MoeModel()
+
+    for name, param in model.named_parameters():
+        assert hasattr(param, 'col_attr')
+
+        # the weights in the gate should be fp32
+        if 'gate' in name:
+            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+        else:
+            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+
+        # the parameters in moe experts and its gate should not be sharded
+        if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
+            assert not param.col_attr.sharded_data_tensor.is_sharded
+        else:
+            assert param.col_attr.sharded_data_tensor.is_sharded
+
+        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+
+
+def _run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    MOE_CONTEXT.setup(seed=42)
+    run_moe_zero_init()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_moe_zero_init(world_size):
+    run_func = partial(_run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_zero_init(world_size=2)
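
The test above also serves as a reference for how the new helpers compose with ZeroInitContext. Below is a minimal usage sketch, not part of the patch itself: it assumes a ColossalAI distributed environment has already been launched (for example via colossalai.launch), and the module and dimension names (UnshardedSubModule, HybridModel, dim) are illustrative only.

# Usage sketch (illustrative, not from the patch): parameters created under
# no_shard_zero_context(), or inside an __init__ wrapped by no_shard_zero_decrator,
# stay unsharded even though the enclosing ZeroInitContext uses shard_param=True.
# Assumes colossalai.launch(...) has already initialized the process groups.
import torch
import torch.nn as nn

from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import (ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator)
from colossalai.zero.shard_utils import TensorShardStrategy


class UnshardedSubModule(nn.Module):
    """Hypothetical submodule whose parameters should keep their full payload."""

    @no_shard_zero_decrator
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)


class HybridModel(nn.Module):
    """Hypothetical model mixing sharded and unsharded submodules."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.body = nn.Linear(dim, dim)        # sharded as usual
        self.head = UnshardedSubModule(dim)    # kept unsharded by the decorator
        with no_shard_zero_context():          # context-manager form
            self.gate = nn.Linear(dim, 2)      # also kept unsharded


with ZeroInitContext(target_device=torch.device(f"cuda:{get_current_device()}"),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = HybridModel()

for name, param in model.named_parameters():
    # body.* is sharded; head.* and gate.* report is_sharded == False
    print(name, param.col_attr.sharded_data_tensor.is_sharded)

Because hijack_context_config simply yields when no ZeroInitContext is active, both the decorator and the with-block form are no-ops outside ZeRO initialization, so the same model code runs unchanged without the context.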