mirror of https://github.com/hpcaitech/ColossalAI
[zero] add zero context manager to change config during initialization (#546)
parent ec5086c49c
commit 8c90d4df54
@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.zero.init_ctx import no_shard_zero_decrator
 from typing import Type


@@ -34,6 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

+    @no_shard_zero_decrator
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)

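The hunk above applies the new `no_shard_zero_decrator` to `Experts.__init__`. Below is a hedged usage sketch of what that buys. It assumes a process group has already been launched with `colossalai.launch()` and `MOE_CONTEXT.setup()` (as the test later in this commit does), and it assumes `Experts` is importable from `colossalai.nn.layer.moe`; neither assumption is shown in this diff.

import torch
import torch.nn as nn
from colossalai.nn.layer.moe import Experts          # import path assumed
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     model_numel_tensor=torch.zeros(1, dtype=torch.int),
                     rm_torch_payload_on_the_fly=False):
    # Because Experts.__init__ is wrapped by @no_shard_zero_decrator, the
    # expert parameters created here are left unsharded even though the
    # surrounding context was entered with shard_param=True.
    experts = Experts(expert_cls=nn.Linear, num_experts=4,
                      in_features=16, out_features=16)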
@@ -1,3 +1,4 @@
 import functools
+import math

 import torch
@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
 from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup

@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask


-class FP32LinearGate(nn.Linear):
+class FP32LinearGate(nn.Module):
     """Gate module used in MOE layer. Just a linear function without bias.
     But it should be kept as fp32 forever.

@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
         weight (ForceFP32Parameter): The weight of linear gate
     """

-    def __init__(self, d_model: int, num_experts: int):
-        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
-        self.weight = ForceFP32Parameter(self.weight)
+    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
+        super().__init__()
+        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
+        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
+
+    def forward(self, x: torch.Tensor):
+        return F.linear(x, self.weight)


 class MoeLayer(nn.Module):
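`FP32LinearGate` no longer subclasses `nn.Linear`; it owns a `ForceFP32Parameter` weight and applies `F.linear` itself, so the gate can stay in fp32 while `ZeroInitContext` casts everything else to half. Below is a self-contained toy illustration of that idea; `KeepFP32Parameter` and `ToyGate` are stand-ins invented for this sketch, not ColossalAI's actual `ForceFP32Parameter` implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class KeepFP32Parameter(nn.Parameter):
    """Toy stand-in for ForceFP32Parameter: refuse to be downcast to half."""

    def half(self, memory_format=None):
        # Returning the fp32 payload means the half-cast in ZeroInitContext's
        # post-init hook leaves this parameter in float32.
        return self.data

class ToyGate(nn.Module):
    """Shape-compatible sketch of the rewritten FP32LinearGate."""

    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
        super().__init__()
        self.weight = KeepFP32Parameter(torch.empty(num_experts, d_model))
        nn.init.trunc_normal_(self.weight, std=(scale / d_model) ** 0.5)

    def forward(self, x: torch.Tensor):
        return F.linear(x, self.weight)

gate = ToyGate(d_model=8, num_experts=4)
print(gate.weight.half().dtype)    # torch.float32, unlike a plain nn.Parameter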
@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

+    @no_shard_zero_decrator
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
@@ -361,7 +368,6 @@ class MoeModule(nn.Module):
                                       min_capacity=min_capacity,
                                       noisy_func=noisy_func,
                                       drop_tks=drop_tks)
-
         self.use_residual = use_residual
         if use_residual:
             if residual_instance is not None:
@@ -371,7 +377,8 @@ class MoeModule(nn.Module):
                     "Expert class can't be None when residual instance is not given"
                 self.residual_module = expert_cls(**expert_args)

-            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+            with no_shard_zero_context():
+                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

         if expert_instance is not None:
             self.experts = expert_instance

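This file uses both opt-out hooks: `MoeLayer.__init__` is wrapped by `@no_shard_zero_decrator`, while `MoeModule` shields only `residual_combine` with `with no_shard_zero_context():`. A minimal sketch of the same two patterns in a user-defined module (the classes here are hypothetical examples, not part of the library):

import torch.nn as nn
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator

class MyExpertBlock(nn.Module):
    """Hypothetical module: every parameter created in __init__ stays unsharded."""

    @no_shard_zero_decrator
    def __init__(self, d_model: int):
        super().__init__()
        self.ffn = nn.Linear(d_model, d_model)

class MyGatedBlock(nn.Module):
    """Hypothetical module: only the gate opts out, the projection is sharded as usual."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)
        with no_shard_zero_context():
            self.gate = nn.Linear(d_model, num_experts, bias=False)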
@@ -1,3 +1,3 @@
-from .init_context import ZeroInitContext
+from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator

-__all__ = ['ZeroInitContext']
+__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']

@@ -1,9 +1,11 @@
+import contextlib
 import functools
 from typing import Optional

 import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
@@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
         pass


+class ZeroContextConfig(object):
+    """The configuration used to control zero context initialization.
+
+    Args:
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
+            This will reduce memory usage when initializing model.
+            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on `param.data` after the context exits.
+            This is used when you add some logic to operate tensors in `__init__` of module.
+            See torchvision resnet18. Defaults to False.
+    """
+
+    def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+        super().__init__()
+        self.shard_param: bool = shard_param
+        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
+
+
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.

@@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.

     Args:
-        convert_fp16 (bool): Whether to convert params to fp16.
         target_device (torch.device): The device where param data after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
-        shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
            This will reduce memory usage when initializing model.
            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.

@@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         super().__init__()
         self.target_device = target_device
-        self.shard_param = shard_param
         self.shard_strategy = shard_strategy
-        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

+        self.config = ZeroContextConfig(shard_param=shard_param,
+                                        rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
+        ZeroContextMgr().current_context = self
+
+    @property
+    def shard_param(self):
+        return self.config.shard_param
+
+    @property
+    def rm_torch_payload_on_the_fly(self):
+        return self.config.rm_torch_payload_on_the_fly
+
     def _pre_context_exec(self):
         """
         The Callback function when entering the context

@@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
+
+        def half_fn(t: torch.Tensor):
+            return t.half() if t.is_floating_point() else t
+
         for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice
            if hasattr(param, 'col_attr'):
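The new `half_fn` helper downcasts only floating-point tensors, and because it calls `param.half()` on the parameter object rather than `param.data.to(torch.half)`, a parameter type that overrides `half()` (such as the gate's `ForceFP32Parameter`) can keep itself in fp32, which is what the new test asserts. A quick standalone check of the dtype behaviour:

import torch

def half_fn(t: torch.Tensor):
    return t.half() if t.is_floating_point() else t

print(half_fn(torch.randn(4)).dtype)     # torch.float16
print(half_fn(torch.arange(4)).dtype)    # torch.int64, left unchanged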
@@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

             self.model_numel_tensor += param.numel()

-            target_device = self.target_device
-
-            # convert to fp16
-            param.data = param.data.to(torch.half)
+            # convert parameters to half
+            param_half = half_fn(param)
+            param.data = param_half
             if param.grad is not None:
-                param.grad = param.grad.to(torch.half)
+                grad_half = half_fn(param.grad)
+                param.grad.data = grad_half

             # move torch parameters to the target device
+            target_device = self.target_device
             param.data = param.data.to(target_device)
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)

             param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

+            self.initialized_param_list.append(param)
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-            self.initialized_param_list.append(param)

             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
@@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
             buffer.data = cast_tensor_to_fp16(buffer.data)
+
+
+class ZeroContextMgr(metaclass=SingletonMeta):
+    current_context: Optional[ZeroInitContext] = None
+
+    @contextlib.contextmanager
+    def hijack_context_config(self, **kwargs):
+        if self.current_context is None:
+            yield
+        else:
+            old_config = self.current_context.config
+            self.current_context.config = ZeroContextConfig(**kwargs)
+            yield
+            self.current_context.config = old_config
+
+
+def no_shard_zero_context():
+    return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
+
+
+def no_shard_zero_decrator(init_func):
+
+    def _no_shard(*args, **kwargs):
+        with no_shard_zero_context():
+            init_func(*args, **kwargs)
+
+    return _no_shard

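This is the heart of the change: `ZeroInitContext` now reads `shard_param` and `rm_torch_payload_on_the_fly` through a swappable `ZeroContextConfig`, registers itself on the `ZeroContextMgr` singleton, and `hijack_context_config` swaps that config in and out. A hedged sketch of the flow; the direct `init_context` submodule import is assumed from the diff, and outside an active `ZeroInitContext` the hijack is a no-op.

import torch.nn as nn
from colossalai.zero.init_ctx import no_shard_zero_context
from colossalai.zero.init_ctx.init_context import ZeroContextMgr   # path assumed from this diff

mgr = ZeroContextMgr()              # SingletonMeta: every call returns the same instance
assert mgr is ZeroContextMgr()

# No ZeroInitContext has been entered, so current_context is None and the
# hijack simply yields without touching any configuration.
with no_shard_zero_context():
    layer = nn.Linear(8, 8)         # initialized normally

# Inside a live ZeroInitContext, the same block would install
# ZeroContextConfig(shard_param=False, rm_torch_payload_on_the_fly=False),
# let the body run (so its parameters are not sharded by the post-init hook),
# and then restore the previous config on exit.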
@@ -0,0 +1,97 @@
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from colossalai.logging import get_dist_logger
+from colossalai.testing import parameterize
+from colossalai.utils import free_port
+from colossalai.context import MOE_CONTEXT
+from colossalai.nn.layer import MoeModule
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+
+from colossalai.testing import rerun_on_exception
+from colossalai.utils import get_current_device
+from tests.test_zero_data_parallel.common import CONFIG
+
+
+class MoeModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.proj1 = nn.Linear(4, 8)
+        expert_cls = nn.Linear
+        expert_args_dict = dict(in_features=8, out_features=8)
+        self.moe = MoeModule(dim_model=8,
+                             num_experts=8,
+                             noisy_policy='Jitter',
+                             use_residual=True,
+                             expert_cls=expert_cls,
+                             **expert_args_dict)
+        self.proj2 = nn.Linear(8, 4)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        x = self.moe(x)
+        x = self.proj2(x)
+        return x
+
+
+@parameterize("init_device_type", ['cpu', 'cuda'])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_moe_zero_init(init_device_type, shard_strategy_class):
+    logger = get_dist_logger("test_moe_zero_init")
+
+    if init_device_type == 'cuda':
+        init_device = torch.device(f"cuda:{get_current_device()}")
+    elif init_device_type == 'cpu':
+        init_device = torch.device("cpu")
+    else:
+        raise NotImplementedError("Unknown device found.")
+
+    model_numel_tensor = torch.zeros(1, dtype=torch.int)
+    with ZeroInitContext(target_device=init_device,
+                         shard_strategy=shard_strategy_class(),
+                         shard_param=True,
+                         model_numel_tensor=model_numel_tensor,
+                         rm_torch_payload_on_the_fly=False):
+        model = MoeModel()
+
+    for name, param in model.named_parameters():
+        assert hasattr(param, 'col_attr')
+
+        # the weights in the gate should be fp32
+        if 'gate' in name:
+            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+        else:
+            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+
+        # the parameters in moe experts and its gate should not be sharded
+        if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
+            assert not param.col_attr.sharded_data_tensor.is_sharded
+        else:
+            assert param.col_attr.sharded_data_tensor.is_sharded
+
+        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+
+
+def _run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    MOE_CONTEXT.setup(seed=42)
+    run_moe_zero_init()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_moe_zero_init(world_size):
+    run_func = partial(_run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_zero_init(world_size=2)