[hotfix] fix Gemini for the no-leaf-modules bug (#2043)

pull/2045/head
Jiarui Fang 2022-11-30 14:53:41 +08:00 committed by GitHub
parent 384cd26314
commit 31c644027b
2 changed files with 82 additions and 28 deletions

View File: colossalai/utils/model/colo_init_context.py

@@ -1,10 +1,10 @@
-from typing import Dict, Iterator, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import torch
 from torch import nn

 from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
-from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup

 from .utils import InsertPostInitMethodToModuleSubClasses
@@ -26,6 +26,34 @@ def _named_params_with_replica(
             yield name, val


+def _convert_to_coloparam(param: torch.nn.Parameter,
+                          device: torch.device,
+                          dtype=torch.float,
+                          default_pg: Optional[ProcessGroup] = None,
+                          default_dist_spec: Optional[Any] = None) -> ColoParameter:
+    if isinstance(param, ColoParameter):
+        return param
+    # detaching tensor is necessary for optimizers.
+    requires_grad = param.requires_grad
+    # param is the global tensor.
+    colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
+    # if default_shard_plan exists, shard the param during initialization.
+    # This can reduce the model size after initialization.
+    # NOTE: embeddings usually cannot be sharded correctly, so try/except handles
+    # params that cannot be sharded by the default plan.
+    if default_pg is not None:
+        colo_param.set_process_group(default_pg)
+    if default_dist_spec is not None:
+        try:
+            colo_param.set_dist_spec(default_dist_spec)
+        except:
+            pass
+    return colo_param
+
+
 def ColoModulize(module):
     """
     Replacing the parameters() and named_parameters() with our customized ones
@@ -94,26 +122,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             if param in replaced_tensors:
                 colo_param = replaced_tensors[param]
             else:
-                # detaching tensor is necessary for optimizers.
-                requires_grad = param.requires_grad
-                # param is the global tensor.
-                colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
-                                           requires_grad=requires_grad)
-                # if default_shard_plan exists, shard the param during initialization.
-                # This can reduce the model size after initialization.
-                # NOTE() embedding usually can not be correctly sharded. So I use except to handle
-                # the param that can not be sharded by the default plan
-                if self._default_pg is not None:
-                    colo_param.set_process_group(self._default_pg)
-                if self._default_dist_spec is not None:
-                    try:
-                        colo_param.set_dist_spec(self._default_dist_spec)
-                    except:
-                        pass
+                colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg,
+                                                   self._default_dist_spec)
             replaced_tensors[param] = colo_param
             delattr(submodule, param_name)
             setattr(submodule, param_name, colo_param)
@@ -121,3 +131,39 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         module.to(self._device)
         ColoModulize(module)
+
+
+def post_process_colo_init_ctx(model: torch.nn.Module,
+                               device: torch.device = torch.device('cpu'),
+                               dtype: torch.dtype = torch.float,
+                               default_pg: Optional[ProcessGroup] = None,
+                               default_dist_spec=None):
+    """post_process_colo_init_ctx
+
+    This function is called after `ColoInitContext`.
+
+    Args:
+        model (torch.nn.Module): the model.
+        device (torch.device, optional): device type of the model params. Defaults to torch.device('cpu').
+        dtype (torch.dtype, optional): dtype of the model params. Defaults to torch.float.
+        default_pg (Optional[ProcessGroup], optional): default process group. Defaults to None, which indicates a DP-only process group.
+        default_dist_spec (Any, optional): default dist spec of params. Defaults to None.
+
+    Raises:
+        RuntimeError: raised if a parameter is still not a ColoTensor after conversion.
+    """
+    torch_params = []
+    for n, p in model.named_parameters():
+        if not isinstance(p, ColoParameter):
+            print(f"{n} is not a ColoParameter. We are going to convert it to a ColoParameter")
+            torch_params.append((n, p))
+
+    for (n, param) in torch_params:
+        delattr(model, n)
+        setattr(model, n, _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))
+
+    del torch_params
+    for n, p in model.named_parameters():
+        if not isinstance(p, ColoTensor):
+            raise RuntimeError
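
The intended calling pattern, mirrored by the test changes below, is to build the model inside ColoInitContext and then call post_process_colo_init_ctx with the same device; the helper converts any parameters that are still plain torch.nn.Parameter and raises RuntimeError if anything remains unconverted. A minimal sketch, assuming ColossalAI has already been launched, with a hypothetical module that holds a raw nn.Parameter directly on itself:

import torch
import torch.nn as nn

from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx


class RawParamNet(nn.Module):
    # hypothetical module: a raw nn.Parameter lives directly on the module,
    # next to an ordinary leaf submodule
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        return self.proj(x) @ self.weight


init_dev = get_current_device()
with ColoInitContext(device=init_dev):
    model = RawParamNet()

# convert any parameters the init context left as plain torch.nn.Parameter
post_process_colo_init_ctx(model, device=init_dev)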

View File

@@ -15,10 +15,11 @@ from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
@@ -40,8 +41,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
 # 'gpt2', 'bert',
-TEST_MODELS = ['gpt2', 'bert']
-EXAMPLE_MODELS = ['simple_net']
+TEST_MODELS = ['no_leaf_module', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']


 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@@ -57,8 +57,12 @@ def exam_model_step(placement_policy, model_name: str):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

-    with ColoInitContext(device=get_current_device()):
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
         model = model_builder()

+    post_process_colo_init_ctx(model, device=init_dev)
+
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
@@ -99,7 +103,7 @@ def exam_model_step(placement_policy, model_name: str):
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', EXAMPLE_MODELS)
+@parameterize('model_name', TEST_MODELS)
 def exam_tiny_example(placement_policy, model_name: str):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -111,8 +115,12 @@ def exam_tiny_example(placement_policy, model_name: str):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

-    with ColoInitContext(device=get_current_device()):
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
         model = model_builder()

+    post_process_colo_init_ctx(model, device=init_dev)
+
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
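
For reference, the post-processing can also be exercised against the registered test components outside the full optimizer test; a sketch, assuming ColossalAI has been launched and that the registry callable returns the model builder as its first component, as these tests do:

from colossalai.tensor import ColoParameter
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx

import tests.components_to_test    # assumption: importing the package registers the components
from tests.components_to_test.registry import non_distributed_component_funcs

get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
model_builder = get_components_func()[0]    # assumption: the model builder is the first returned component

init_dev = get_current_device()
with ColoInitContext(device=init_dev):
    model = model_builder()

post_process_colo_init_ctx(model, device=init_dev)

# after post-processing every parameter should be a ColoParameter;
# post_process_colo_init_ctx itself raises RuntimeError otherwise
assert all(isinstance(p, ColoParameter) for p in model.parameters())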