[gemini] fix init bugs for modules (#2047)

* [gemini] fix init bugs for modules

* fix bugs
Branch: pull/2051/head
HELSON committed 2 years ago via GitHub
parent 81e0da7fa8
commit f6178728a0

@@ -96,10 +96,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
         """
-        if hasattr(module, '_colo_visited'):
-            return
-
         name_list = []
         for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
@@ -130,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 colo_param.shared_param_modules.append(submodule)

         module.to(self._device)
-        ColoModulize(module)


 def post_process_colo_init_ctx(model: torch.nn.Module,
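With the `_colo_visited` early return and the `ColoModulize(module)` call removed, `_post_init_method` now runs unconditionally for every module built inside the context and only converts parameters and moves them to the context device. For context, a minimal usage sketch; the import path and the printed type are assumptions based on the ColossalAI version this commit targets, not part of the PR:

# Minimal sketch, assuming the colossalai.utils.model.colo_init_context import path
import torch
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=torch.device('cpu')):
    # every nn.Module constructed inside the context goes through _post_init_method,
    # which after this PR has no '_colo_visited' guard and no ColoModulize wrapping
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

# parameters are expected to be ColoTensor/ColoParameter instances on the given device
print(type(next(model.parameters())), next(model.parameters()).device)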

@@ -24,6 +24,11 @@ from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed

+# this model is large enough to slice to chunks
+TEST_MODELS = ['gpt2']
+# these models are too small, all parameters in these models are compacted into one chunk
+EXAMPLE_MODELS = ['hanging_param_model', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
+

 def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
     zero_dict = model.state_dict(only_rank_0=False)
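check_param (partially visible above and in the next hunk) compares the Gemini/ZeroDDP copy of every parameter against the reference torch model. A hedged sketch of that comparison pattern; only the two calls shown in this diff are taken from the test, the helper name and loop body are illustrative:

# Sketch of the comparison pattern; compare_state_dicts is an illustrative name
import torch
from torch.testing import assert_close

def compare_state_dicts(zero_dict: dict, torch_dict: dict):
    # zero_dict comes from ZeroDDP.state_dict(only_rank_0=False), so every rank
    # holds a full copy of the parameters even though they are stored in chunks
    for key, value in torch_dict.items():
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)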
@@ -40,10 +45,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)


-# 'gpt2', 'bert',
-TEST_MODELS = ['hanging_param_model', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']
-
-
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', TEST_MODELS)
 def exam_model_step(placement_policy, model_name: str):
@@ -61,8 +62,6 @@ def exam_model_step(placement_policy, model_name: str):
     with ColoInitContext(device=init_dev):
         model = model_builder()

-    post_process_colo_init_ctx(model, device=init_dev)
-
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
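The only setup change in exam_model_step is dropping the explicit post_process_colo_init_ctx pass: since _post_init_method now processes every module (first hunk), constructing the model inside the context is enough. A sketch of the simplified setup; model_builder and init_dev stand in for objects the test gets from the component registry, and the import path is an assumption:

# Sketch only; build_and_sync is an illustrative helper, not part of the test file
import torch
from colossalai.utils.model.colo_init_context import ColoInitContext

def build_and_sync(model_builder, torch_model: torch.nn.Module, init_dev: torch.device):
    # constructing inside the context is now sufficient; the explicit
    # post_process_colo_init_ctx(model, device=init_dev) call was removed
    with ColoInitContext(device=init_dev):
        model = model_builder()
    # start both models from identical weights so the later check_param
    # comparison after optimizer steps is meaningful
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)
    return model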
@@ -102,8 +101,8 @@ def exam_model_step(placement_policy, model_name: str):
     check_param(model, torch_model)


-@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', TEST_MODELS)
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('model_name', EXAMPLE_MODELS)
 def exam_tiny_example(placement_policy, model_name: str):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
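With the decorator change, exam_tiny_example now covers all four placement policies against the small single-chunk models, while exam_model_step keeps the large gpt2 model. As I understand colossalai.testing.parameterize, stacked decorators behave like nested loops over the listed values, so the wrapped function runs once per (policy, model) pair; a hedged sketch with an illustrative function name:

# Sketch under the assumption that parameterize calls the wrapped function once per
# value, passing it as a keyword argument; show_case is illustrative, not from the PR
from colossalai.testing import parameterize

@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', ['hanging_param_model', 'bert', 'simple_net',
                             'nested_model', 'repeated_computed_layers'])
def show_case(placement_policy, model_name):
    # 4 policies x 5 example models = 20 combinations
    print(placement_policy, model_name)

show_case()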
@@ -119,8 +118,6 @@ def exam_tiny_example(placement_policy, model_name: str):
     with ColoInitContext(device=init_dev):
         model = model_builder()

-    post_process_colo_init_ctx(model, device=init_dev)
-
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
