[Gemini] test step-tensor mapping using repeated_computed_layers.py (#2127)

pull/2129/head
Jiarui Fang 2022-12-13 16:34:10 +08:00 committed by GitHub
parent 8fac837679
commit deee317b0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 2 deletions

View File

@ -4,10 +4,15 @@ from . import (
hanging_param_model,
inline_op_model,
nested_model,
repeated_computed_layer,
repeated_computed_layers,
resnet,
simple_net,
)
from .utils import run_fwd_bwd
from . import albert # isort:skip
__all__ = [
'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet',
'simple_net', 'run_fwd_bwd', 'albert'
]

View File

@ -23,7 +23,7 @@ from tests.test_tensor.common_utils import set_seed
@parameterize('placement_policy', ['auto'])
@parameterize('keep_gather', [False])
@parameterize('model_name', ['bert', 'albert', 'gpt2'])
@parameterize('model_name', ['repeated_computed_layers', 'bert', 'albert', 'gpt2'])
@parameterize('use_grad_checkpoint', [False, True])
def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
set_seed(42)
@ -49,6 +49,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
print('runtime tracer: ', runtime_tracer_non_model_data)
print([memstats.param_used_timestep(p) for p in model.parameters()])
if model_name == 'repeated_computed_layers':
for idx, p in enumerate(model.parameters()):
step_list = memstats.param_used_timestep(p)
if idx < 4:
assert len(step_list) == 4
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000