[zero] add zero context manager to change config during initialization (#546)

HELSON 2022-03-29 17:57:59 +08:00 committed by GitHub
parent ec5086c49c
commit 8c90d4df54
5 changed files with 185 additions and 18 deletions

View File

@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.zero.init_ctx import no_shard_zero_decrator
 from typing import Type
@@ -34,6 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

+    @no_shard_zero_decrator
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
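For orientation, here is a small, self-contained sketch of the pattern this decorator implements (the real no_shard_zero_decrator and ZeroContextMgr are defined later in this commit; the names below are illustrative stand-ins, not the ColossalAI API): a decorated __init__ temporarily flips the active zero config to "no sharding" while the module builds its parameters, then restores it.

import contextlib

# Illustrative stand-ins only -- not the ColossalAI API.
_config = {"shard_param": True}    # pretend a ZeroInitContext with shard_param=True is active

@contextlib.contextmanager
def _no_shard_context():
    old = dict(_config)
    _config["shard_param"] = False  # hijack the config while the body runs
    try:
        yield
    finally:
        _config.update(old)         # restore the original config afterwards

def _no_shard_decorator(init_func):
    def wrapper(*args, **kwargs):
        with _no_shard_context():
            init_func(*args, **kwargs)
    return wrapper

class _FakeExperts:
    @_no_shard_decorator
    def __init__(self):
        # parameters created here would see shard_param == False
        print("shard_param during expert init:", _config["shard_param"])

_FakeExperts()
print("shard_param afterwards:", _config["shard_param"])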

View File

@@ -1,3 +1,4 @@
+import functools
 import math
 import torch
@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
 from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask

-class FP32LinearGate(nn.Linear):
+class FP32LinearGate(nn.Module):
     """Gate module used in MOE layer. Just a linear function without bias.
     But it should be kept as fp32 forever.
@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
         weight (ForceFP32Parameter): The weight of linear gate
     """

-    def __init__(self, d_model: int, num_experts: int):
-        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
-        self.weight = ForceFP32Parameter(self.weight)
+    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
+        super().__init__()
+        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
+        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
+
+    def forward(self, x: torch.Tensor):
+        return F.linear(x, self.weight)
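The test added by this commit asserts that the gate weight stays fp32 even though ZeroInitContext halves every other parameter. A minimal sketch of how that can work, assuming ForceFP32Parameter (defined in .utils, not shown in this diff) simply makes half() a no-op; the class below is a hypothetical stand-in, not the ColossalAI implementation:

import torch
import torch.nn as nn

class _ForceFP32Parameter(torch.nn.Parameter):
    """Hypothetical stand-in: a Parameter whose half() returns the fp32 payload unchanged."""

    def half(self, memory_format=None):
        return self.data    # stays torch.float32

gate_weight = _ForceFP32Parameter(torch.empty(4, 8))
plain_weight = nn.Parameter(torch.empty(4, 8))

# A blanket "convert parameters to half" pass (as ZeroInitContext does) leaves the gate in fp32.
print(gate_weight.half().dtype)     # torch.float32
print(plain_weight.half().dtype)    # torch.float16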
 class MoeLayer(nn.Module):
@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

+    @no_shard_zero_decrator
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
@@ -361,7 +368,6 @@ class MoeModule(nn.Module):
                                      min_capacity=min_capacity,
                                      noisy_func=noisy_func,
                                      drop_tks=drop_tks)
-
         self.use_residual = use_residual
         if use_residual:
             if residual_instance is not None:
@@ -371,6 +377,7 @@ class MoeModule(nn.Module):
                     "Expert class can't be None when residual instance is not given"
                 self.residual_module = expert_cls(**expert_args)

-            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+            with no_shard_zero_context():
+                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

         if expert_instance is not None:

View File

@@ -1,3 +1,3 @@
-from .init_context import ZeroInitContext
+from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator

-__all__ = ['ZeroInitContext']
+__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']

View File

@@ -1,9 +1,11 @@
+import contextlib
 import functools
 from typing import Optional

 import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
@@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
         pass
+
+
+class ZeroContextConfig(object):
+    """The configuration used to control zero context initialization.
+
+    Args:
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
+            This will reduce memory usage when initializing model.
+            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on `param.data` after the context exits.
+            This is used when you add some logic to operate tensors in `__init__` of module.
+            See torchvision resnet18. Defaults to False.
+    """
+
+    def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+        super().__init__()
+        self.shard_param: bool = shard_param
+        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly

 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.
@@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.

     Args:
-        convert_fp16 (bool): Whether to convert params to fp16.
         target_device (torch.device): The device where param data after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
-        shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self.target_device = target_device
-        self.shard_param = shard_param
         self.shard_strategy = shard_strategy
-        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

+        self.config = ZeroContextConfig(shard_param=shard_param,
+                                        rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
+
+        ZeroContextMgr().current_context = self
+
+    @property
+    def shard_param(self):
+        return self.config.shard_param
+
+    @property
+    def rm_torch_payload_on_the_fly(self):
+        return self.config.rm_torch_payload_on_the_fly
     def _pre_context_exec(self):
         """
         The Callback function when entering the context
@@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """

+        def half_fn(t: torch.Tensor):
+            return t.half() if t.is_floating_point() else t
+
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'col_attr'):
@@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             self.model_numel_tensor += param.numel()

-            target_device = self.target_device
-            # convert to fp16
-            param.data = param.data.to(torch.half)
+            # convert parameters to half
+            param_half = half_fn(param)
+            param.data = param_half
             if param.grad is not None:
-                param.grad = param.grad.to(torch.half)
+                grad_half = half_fn(param.grad)
+                param.grad.data = grad_half

             # move torch parameters to the target device
+            target_device = self.target_device
             param.data = param.data.to(target_device)
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)

             param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            self.initialized_param_list.append(param)

             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-            self.initialized_param_list.append(param)

             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
@@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
             buffer.data = cast_tensor_to_fp16(buffer.data)
+
+
+class ZeroContextMgr(metaclass=SingletonMeta):
+    current_context: Optional[ZeroInitContext] = None
+
+    @contextlib.contextmanager
+    def hijack_context_config(self, **kwargs):
+        if self.current_context is None:
+            yield
+        else:
+            old_config = self.current_context.config
+            self.current_context.config = ZeroContextConfig(**kwargs)
+            yield
+            self.current_context.config = old_config
+
+
+def no_shard_zero_context():
+    return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
+
+
+def no_shard_zero_decrator(init_func):
+
+    def _no_shard(*args, **kwargs):
+        with no_shard_zero_context():
+            init_func(*args, **kwargs)
+
+    return _no_shard
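Taken together, a minimal usage sketch of the new helpers; this assumes colossalai.launch(...) has already initialized the distributed environment, since ZeroInitContext falls back to gpc.get_group(ParallelMode.DATA), and mirrors the arguments used by the test added below.

import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext, no_shard_zero_context
from colossalai.zero.shard_utils import TensorShardStrategy

with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     model_numel_tensor=torch.zeros(1, dtype=torch.int),
                     rm_torch_payload_on_the_fly=False):
    backbone = nn.Linear(8, 8)        # its parameter is sharded across the data-parallel group
    with no_shard_zero_context():     # hijacks the active config: shard_param=False
        gate = nn.Linear(8, 2)        # kept whole on every rank, like MoeModule.residual_combine above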

View File

@@ -0,0 +1,97 @@
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from colossalai.logging import get_dist_logger
+from colossalai.testing import parameterize
+from colossalai.utils import free_port
+from colossalai.context import MOE_CONTEXT
+from colossalai.nn.layer import MoeModule
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.testing import rerun_on_exception
+from colossalai.utils import get_current_device
+from tests.test_zero_data_parallel.common import CONFIG
+
+
+class MoeModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.proj1 = nn.Linear(4, 8)
+        expert_cls = nn.Linear
+        expert_args_dict = dict(in_features=8, out_features=8)
+        self.moe = MoeModule(dim_model=8,
+                             num_experts=8,
+                             noisy_policy='Jitter',
+                             use_residual=True,
+                             expert_cls=expert_cls,
+                             **expert_args_dict)
+        self.proj2 = nn.Linear(8, 4)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        x = self.moe(x)
+        x = self.proj2(x)
+        return x
+
+
+@parameterize("init_device_type", ['cpu', 'cuda'])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_moe_zero_init(init_device_type, shard_strategy_class):
+    logger = get_dist_logger("test_moe_zero_init")
+
+    if init_device_type == 'cuda':
+        init_device = torch.device(f"cuda:{get_current_device()}")
+    elif init_device_type == 'cpu':
+        init_device = torch.device("cpu")
+    else:
+        raise NotImplementedError("Unknown device found.")
+
+    model_numel_tensor = torch.zeros(1, dtype=torch.int)
+    with ZeroInitContext(target_device=init_device,
+                         shard_strategy=shard_strategy_class(),
+                         shard_param=True,
+                         model_numel_tensor=model_numel_tensor,
+                         rm_torch_payload_on_the_fly=False):
+        model = MoeModel()
+
+    for name, param in model.named_parameters():
+        assert hasattr(param, 'col_attr')
+
+        # the weights in the gate should be fp32
+        if 'gate' in name:
+            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+        else:
+            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+
+        # the parameters in moe experts and its gate should not be sharded
+        if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
+            assert not param.col_attr.sharded_data_tensor.is_sharded
+        else:
+            assert param.col_attr.sharded_data_tensor.is_sharded
+
+        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+
+
+def _run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    MOE_CONTEXT.setup(seed=42)
+    run_moe_zero_init()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_moe_zero_init(world_size):
+    run_func = partial(_run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_zero_init(world_size=2)