[zero] add zero context manager to change config during initialization (#546)

pull/501/head
HELSON 2022-03-29 17:57:59 +08:00 committed by GitHub
parent ec5086c49c
commit 8c90d4df54
5 changed files with 185 additions and 18 deletions
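In short, this commit lets module code temporarily override the configuration of the active ZeroInitContext while the module is being constructed, so that selected parameters (MoE experts and gates here) are not sharded. A minimal sketch of the intended usage, assuming a hypothetical MyExperts module; only the imported helper comes from this commit:

import torch.nn as nn
from colossalai.zero.init_ctx import no_shard_zero_decrator

class MyExperts(nn.Module):  # hypothetical module, for illustration only

    @no_shard_zero_decrator
    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        # Submodules created here are built with shard_param=False, so their
        # parameters stay unsharded even when the whole model is created
        # inside a ZeroInitContext(shard_param=True).
        self.experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])

This mirrors how the diff below decorates Experts.__init__ and MoeLayer.__init__.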


@@ -5,6 +5,7 @@ import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.zero.init_ctx import no_shard_zero_decrator
from typing import Type
@@ -34,6 +35,7 @@ class Experts(MoeExperts):
expert_args: Args used to initialize experts; they can be found in the corresponding expert class
"""
@no_shard_zero_decrator
def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
super().__init__("all_to_all", num_experts)


@@ -1,3 +1,4 @@
import functools
import math
import torch
@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts, Experts
from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
from typing import Callable, Optional, Type
from torch.distributed import ProcessGroup
@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
return cb_weight, sec_mask
class FP32LinearGate(nn.Linear):
class FP32LinearGate(nn.Module):
"""Gate module used in MOE layer. Just a linear function without bias.
But it should be kept as fp32 forever.
@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
weight (ForceFP32Parameter): The weight of linear gate
"""
def __init__(self, d_model: int, num_experts: int):
super().__init__(d_model, num_experts, bias=False, device=get_current_device())
self.weight = ForceFP32Parameter(self.weight)
def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
super().__init__()
self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
def forward(self, x: torch.Tensor):
return F.linear(x, self.weight)
class MoeLayer(nn.Module):
@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
"""
@no_shard_zero_decrator
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
super().__init__()
self.d_model = dim_model
@@ -361,7 +368,6 @@ class MoeModule(nn.Module):
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
self.use_residual = use_residual
if use_residual:
if residual_instance is not None:
@@ -371,7 +377,8 @@ class MoeModule(nn.Module):
"Expert class can't be None when residual instance is not given"
self.residual_module = expert_cls(**expert_args)
self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
with no_shard_zero_context():
self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
if expert_instance is not None:
self.experts = expert_instance
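Besides decorating a whole __init__, the new context manager can also guard a single submodule, which is what the change above does for residual_combine. A minimal sketch of that pattern with a hypothetical ResidualBlock; only no_shard_zero_context and get_current_device come from the library:

import torch.nn as nn
from colossalai.utils import get_current_device
from colossalai.zero.init_ctx import no_shard_zero_context

class ResidualBlock(nn.Module):  # hypothetical module, for illustration only

    def __init__(self, dim_model: int):
        super().__init__()
        # Sharded as usual when the model is built inside a ZeroInitContext.
        self.proj = nn.Linear(dim_model, dim_model, device=get_current_device())
        # Kept unsharded: the context temporarily sets shard_param=False
        # for anything constructed inside the with-block.
        with no_shard_zero_context():
            self.combine = nn.Linear(dim_model, 2, device=get_current_device())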


@@ -1,3 +1,3 @@
from .init_context import ZeroInitContext
from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
__all__ = ['ZeroInitContext']
__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']


@@ -1,9 +1,11 @@
import contextlib
import functools
from typing import Optional
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
@@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
pass
class ZeroContextConfig(object):
"""The configuration used to control zero context initialization.
Args:
shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove the tensor payload on `param.data` after module init finishes.
This reduces memory usage when initializing the model,
but it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
If set to `False`, the tensor payload on `param.data` is removed after the context exits.
This is used when `__init__` contains logic that operates on the module's tensors.
See torchvision resnet18. Defaults to False.
"""
def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
super().__init__()
self.shard_param: bool = shard_param
self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""A context to initialize model.
@@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
3. Shard the param and grad according to flags.
Args:
convert_fp16 (bool): Whether to convert params to fp16.
target_device (torch.device): The device where param data after exiting the context.
shard_strategy (BaseShardStrategy): Shard strategy instance.
shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
shard_grad (bool, optional): Whether the grad is sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove the tensor payload on `param.data` after module init finishes.
This reduces memory usage when initializing the model,
but it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__()
self.target_device = target_device
self.shard_param = shard_param
self.shard_strategy = shard_strategy
self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
self.initialized_param_list = []
self.model_numel_tensor = model_numel_tensor
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
self.config = ZeroContextConfig(shard_param=shard_param,
rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
ZeroContextMgr().current_context = self
@property
def shard_param(self):
return self.config.shard_param
@property
def rm_torch_payload_on_the_fly(self):
return self.config.rm_torch_payload_on_the_fly
def _pre_context_exec(self):
"""
The Callback function when entering the context
@@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
def half_fn(t: torch.Tensor):
return t.half() if t.is_floating_point() else t
for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
if hasattr(param, 'col_attr'):
@@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.model_numel_tensor += param.numel()
target_device = self.target_device
# convert to fp16
param.data = param.data.to(torch.half)
# convert parameters to half
param_half = half_fn(param)
param.data = param_half
if param.grad is not None:
param.grad = param.grad.to(torch.half)
grad_half = half_fn(param.grad)
param.grad.data = grad_half
# move torch parameters to the target device
target_device = self.target_device
param.data = param.data.to(target_device)
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
self.initialized_param_list.append(param)
if self.shard_param:
self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
self.initialized_param_list.append(param)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
@@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
for buffer in module.buffers(recurse=False):
buffer.data = buffer.data.to(device=torch.cuda.current_device())
buffer.data = cast_tensor_to_fp16(buffer.data)
class ZeroContextMgr(metaclass=SingletonMeta):
current_context: Optional[ZeroInitContext] = None
@contextlib.contextmanager
def hijack_context_config(self, **kwargs):
if self.current_context is None:
yield
else:
old_config = self.current_context.config
self.current_context.config = ZeroContextConfig(**kwargs)
yield
self.current_context.config = old_config
def no_shard_zero_context():
return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
def no_shard_zero_decrator(init_func):
def _no_shard(*args, **kwargs):
with no_shard_zero_context():
init_func(*args, **kwargs)
return _no_shard
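For reference, a brief sketch of how the hijack works, assuming the file above lives at colossalai/zero/init_ctx/init_context.py: ZeroContextMgr is a singleton, ZeroInitContext registers itself as current_context, and hijack_context_config swaps the live ZeroContextConfig that the new shard_param and rm_torch_payload_on_the_fly properties read.

from colossalai.zero.init_ctx.init_context import ZeroContextMgr

# Outside any ZeroInitContext, current_context is None and the hijack is a no-op.
with ZeroContextMgr().hijack_context_config(shard_param=False,
                                            rm_torch_payload_on_the_fly=False):
    pass

# Inside an active ZeroInitContext `ctx`, the same block temporarily yields
#   ctx.shard_param                  -> False
#   ctx.rm_torch_payload_on_the_fly  -> False
# and the previous ZeroContextConfig is restored when the block exits.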


@@ -0,0 +1,97 @@
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer import MoeModule
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.testing import rerun_on_exception
from colossalai.utils import get_current_device
from tests.test_zero_data_parallel.common import CONFIG
class MoeModel(nn.Module):
def __init__(self):
super().__init__()
self.proj1 = nn.Linear(4, 8)
expert_cls = nn.Linear
expert_args_dict = dict(in_features=8, out_features=8)
self.moe = MoeModule(dim_model=8,
num_experts=8,
noisy_policy='Jitter',
use_residual=True,
expert_cls=expert_cls,
**expert_args_dict)
self.proj2 = nn.Linear(8, 4)
def forward(self, x):
x = self.proj1(x)
x = self.moe(x)
x = self.proj2(x)
return x
@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_moe_zero_init(init_device_type, shard_strategy_class):
logger = get_dist_logger("test_moe_zero_init")
if init_device_type == 'cuda':
init_device = torch.device(f"cuda:{get_current_device()}")
elif init_device_type == 'cpu':
init_device = torch.device("cpu")
else:
raise NotImplementedError("Unknown device found.")
model_numel_tensor = torch.zeros(1, dtype=torch.int)
with ZeroInitContext(target_device=init_device,
shard_strategy=shard_strategy_class(),
shard_param=True,
model_numel_tensor=model_numel_tensor,
rm_torch_payload_on_the_fly=False):
model = MoeModel()
for name, param in model.named_parameters():
assert hasattr(param, 'col_attr')
# the weights in the gate should be fp32
if 'gate' in name:
assert param.col_attr.sharded_data_tensor.dtype == torch.float32
else:
assert param.col_attr.sharded_data_tensor.dtype == torch.half
# the parameters in moe experts and its gate should not be sharded
if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
assert not param.col_attr.sharded_data_tensor.is_sharded
else:
assert param.col_attr.sharded_data_tensor.is_sharded
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
def _run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_CONTEXT.setup(seed=42)
run_moe_zero_init()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_moe_zero_init(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_zero_init(world_size=2)