mirror of https://github.com/hpcaitech/ColossalAI
[gemini] fix init bugs for modules (#2047)
* [gemini] fix init bugs for modules

* fix bugs
parent 81e0da7fa8
commit f6178728a0
@@ -96,10 +96,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
         """
-
-        if hasattr(module, '_colo_visited'):
-            return
-
         name_list = []
         for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
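The FIXME above is the heart of this hunk: a post-init hook installed on nn.Module subclasses fires once per constructor, so a parent module's hook re-visits parameters its children already processed. Below is a minimal, self-contained sketch of that mechanism (hypothetical names, far simpler than InsertPostInitMethodToModuleSubClasses); it shows why the `_colo_visited` early-return removed here existed in the first place, and why the remaining conversion logic must be idempotent.

import torch.nn as nn

def _wrap_init(cls, post_init):
    orig = cls.__init__

    def new_init(self, *args, **kwargs):
        orig(self, *args, **kwargs)   # run the subclass constructor first
        post_init(self)               # then fire the hook on the finished module

    cls.__init__ = new_init
    return orig

def post_init(module):
    # a parent's hook re-sees parameters that a child's hook already handled,
    # so the body has to tolerate repeated visits (the FIXME above)
    print(type(module).__name__, len(list(module.parameters())))

saved = {c: _wrap_init(c, post_init) for c in (nn.Linear, nn.Sequential)}
nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))   # prints Linear 2, Linear 2, Sequential 4
for cls, orig in saved.items():                   # restore the untouched constructors
    cls.__init__ = orig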
@@ -130,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             colo_param.shared_param_modules.append(submodule)
 
         module.to(self._device)
-        ColoModulize(module)
 
 
 def post_process_colo_init_ctx(model: torch.nn.Module,
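For orientation, the call pattern in the patched file: parameters are converted as each module is constructed, the whole module is moved to the context device, and `post_process_colo_init_ctx` remains as a follow-up pass for parameters created outside the context. A usage sketch grounded in the signatures visible in this diff; the import path is assumed from the class shown above, not stated in the hunk.

import torch
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx

init_dev = torch.device('cpu')
with ColoInitContext(device=init_dev):
    model = torch.nn.Linear(8, 8)    # parameters converted at construction time

# after this commit the tests drop this call, but it stays available as a
# safety net for any parameter built outside the context
post_process_colo_init_ctx(model, device=init_dev)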
@@ -24,6 +24,11 @@ from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
 
+# this model is large enough to slice to chunks
+TEST_MODELS = ['gpt2']
+# these models are too small, all parameters in these models are compacted into one chunk
+EXAMPLE_MODELS = ['hanging_param_model', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
+
 
 def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
     zero_dict = model.state_dict(only_rank_0=False)
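The two lists encode which code path a model exercises in Gemini's chunk-based memory manager: gpt2 has enough parameters to be sliced across several chunks, while the example models fit entirely inside one chunk. A toy illustration of that packing rule (a sketch, much simpler than the real ChunkManager):

import torch

def pack_into_chunks(params, chunk_numel: int):
    chunks, current, used = [], [], 0
    for p in params:
        if used + p.numel() > chunk_numel and current:
            chunks.append(current)       # close the full chunk
            current, used = [], 0
        current.append(p)
        used += p.numel()
    if current:
        chunks.append(current)
    return chunks

tiny = [torch.empty(8) for _ in range(4)]              # "compacted into one chunk"
print(len(pack_into_chunks(tiny, chunk_numel=1024)))   # -> 1
big = [torch.empty(512) for _ in range(8)]             # "large enough to slice to chunks"
print(len(pack_into_chunks(big, chunk_numel=1024)))    # -> 4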
@@ -40,10 +45,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
 
 
-# 'gpt2', 'bert',
-TEST_MODELS = ['hanging_param_model', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
-
-
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', TEST_MODELS)
 def exam_model_step(placement_policy, model_name: str):
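check_param (partially visible above) compares the sharded model's gathered state dict against the unsharded reference within fp16-friendly tolerances. A plausible shape for the comparison loop, reconstructed around the assert_close line in the hunk; only that line is taken verbatim from the diff.

import torch
from torch.testing import assert_close

def check_param_sketch(zero_state: dict, torch_model: torch.nn.Module):
    torch_state = torch_model.state_dict()
    for key, value in torch_state.items():
        # align device and dtype so assert_close only judges values
        temp_zero_value = zero_state[key].to(device=value.device, dtype=value.dtype)
        # mixed-precision training rules out exact equality; tolerances match the diff
        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)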
@@ -61,8 +62,6 @@ def exam_model_step(placement_policy, model_name: str):
     with ColoInitContext(device=init_dev):
         model = model_builder()
 
-    post_process_colo_init_ctx(model, device=init_dev)
-
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
 
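Both tests align the two models' weights with an in-place copy after construction; `Tensor.copy_` converts dtype and device on the fly, so the already-converted Colo parameters keep their object identity while taking the reference values. Self-contained with plain torch modules:

import torch

src = torch.nn.Linear(4, 4)
dst = torch.nn.Linear(4, 4)
for torch_p, p in zip(src.parameters(), dst.parameters()):
    p.data.copy_(torch_p.data)    # in-place value copy; p's identity survives

assert all(torch.equal(a, b) for a, b in zip(src.parameters(), dst.parameters()))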
@@ -102,8 +101,8 @@ def exam_model_step(placement_policy, model_name: str):
         check_param(model, torch_model)
 
 
-@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', TEST_MODELS)
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('model_name', EXAMPLE_MODELS)
 def exam_tiny_example(placement_policy, model_name: str):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
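Stacked `@parameterize` decorators run the wrapped test once per combination, so widening the policy list to four entries and switching to EXAMPLE_MODELS turns exam_tiny_example into a 4 x 5 matrix of runs. A minimal sketch of the expansion (ColossalAI ships its own `parameterize` in colossalai.testing; this toy version only mimics the cross product):

import functools

def parameterize(name, values):
    def deco(fn):
        @functools.wraps(fn)
        def wrapper(**kwargs):
            for v in values:                 # one call per value of this argument
                fn(**{**kwargs, name: v})
        return wrapper
    return deco

@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', ['hanging_param_model', 'bert'])
def exam(placement_policy, model_name):
    print(placement_policy, model_name)

exam()    # 4 policies x 2 models = 8 invocations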
@@ -119,8 +118,6 @@ def exam_tiny_example(placement_policy, model_name: str):
     with ColoInitContext(device=init_dev):
         model = model_builder()
 
-    post_process_colo_init_ctx(model, device=init_dev)
-
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
 
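Both exam_* functions follow the same lockstep pattern: identical starting weights, one forward/backward on each side, compare losses, step both optimizers, then compare parameters (via check_param above). The skeleton below is self-contained with plain torch on both sides, standing in for the real test where one side is the ZeroDDP/Gemini model:

import torch
from torch.testing import assert_close

def run_fwd_bwd(model, data, optimizer):
    optimizer.zero_grad()
    loss = model(data).sum()
    loss.backward()
    return loss.detach()

torch.manual_seed(2008)
model_a = torch.nn.Linear(4, 4)
model_b = torch.nn.Linear(4, 4)
for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    pb.data.copy_(pa.data)                      # identical starting weights, as above

opt_a = torch.optim.Adam(model_a.parameters(), lr=1e-3)
opt_b = torch.optim.Adam(model_b.parameters(), lr=1e-3)
data = torch.randn(2, 4)
assert_close(run_fwd_bwd(model_a, data, opt_a),
             run_fwd_bwd(model_b, data, opt_b))  # losses agree before the step
opt_a.step()
opt_b.step()
for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    assert_close(pa, pb)                         # still in lockstep after the step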