diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py
index 3293de7de..63fa2740f 100644
--- a/tests/components_to_test/bert.py
+++ b/tests/components_to_test/bert.py
@@ -40,7 +40,7 @@ def get_training_components():
     num_layer = 2
     vocab_size = 32

-    def bert_model_builder(checkpoint):
+    def bert_model_builder(checkpoint: bool = False):
         config = BertConfig(vocab_size=vocab_size,
                             gradient_checkpointing=checkpoint,
                             hidden_size=hidden_dim,
diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py
index eec1db6e7..ec6299a3c 100644
--- a/tests/test_gemini/update/test_optim.py
+++ b/tests/test_gemini/update/test_optim.py
@@ -18,8 +18,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal
+from tests.test_tensor.common_utils import set_seed


 def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
@@ -37,19 +38,16 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)


-def run_fwd_bwd(model, criterion, optimizer, input_ids):
-    optimizer.zero_grad()
-    logits = model(input_ids)
-    logits = logits.float()
-    loss = criterion(logits, input_ids)
-    optimizer.backward(loss)
-    return logits
+# 'gpt2', 'bert',
+TEST_MODELS = ['gpt2', 'bert']
+# TEST_MODELS = ['simple_net']


 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
-def exam_gpt_fwd_bwd(placement_policy):
+@parameterize('model_name', TEST_MODELS)
+def exam_model_step(placement_policy, model_name: str):
     set_seed(42)
-    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

     with ColoInitContext(device=get_current_device()):
@@ -87,9 +85,13 @@ def exam_gpt_fwd_bwd(placement_policy):
         if i > 2:
             break

-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
-        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
+        zero_optim.zero_grad()
+        torch_optim.zero_grad()
+
+        torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False)
+        loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True)
+
+        assert torch.allclose(torch_loss, loss, rtol=1e-3, atol=1e-2), f"{torch_loss} vs {loss}"
         # debug_print([0], zero_logits, torch_logits)

         zero_optim.step()
@@ -99,9 +101,10 @@


 @parameterize('placement_policy', ['cuda', 'cpu'])
-def exam_tiny_example(placement_policy):
+@parameterize('model_name', TEST_MODELS)
+def exam_tiny_example(placement_policy, model_name: str):
     set_seed(42)
-    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

     with ColoInitContext(device=get_current_device()):
@@ -131,9 +134,13 @@ def exam_tiny_example(placement_policy):
         if i > 2:
             break

-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
-        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
+        zero_optim.zero_grad()
+        torch_optim.zero_grad()
+
+        torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False)
+        loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True)
+
+        assert torch.allclose(torch_loss, loss, rtol=1e-3, atol=1e-2), f"{torch_loss} vs {loss}"
         # debug_print([0], zero_logits, torch_logits)

         zero_optim.step()
@@ -145,17 +152,17 @@ def exam_tiny_example(placement_policy):
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    exam_gpt_fwd_bwd()
+    exam_model_step()
     exam_tiny_example()


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
-def test_gpt(world_size):
+def test_optim(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_gpt(2)
+    test_optim(2)
diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py
index ea2783fb8..7b0c6e37a 100644
--- a/tests/test_gemini/update/test_zeroddp_state_dict.py
+++ b/tests/test_gemini/update/test_zeroddp_state_dict.py
@@ -19,9 +19,10 @@ from tests.test_tensor.common_utils import debug_print, set_seed

 @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
 @parameterize('keep_gathered', [True, False])
-def exam_state_dict(placement_policy, keep_gathered):
+@parameterize('model_name', ['gpt2', 'bert'])
+def exam_state_dict(placement_policy, keep_gathered, model_name: str):
     set_seed(431)
-    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

     with ColoInitContext(device=get_current_device()):
@@ -53,9 +54,10 @@

 @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
 @parameterize('keep_gathered', [True, False])
-def exam_load_state_dict(placement_policy, keep_gathered):
+@parameterize('model_name', ['gpt2', 'bert'])
+def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
     set_seed(431)
-    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

     with ColoInitContext(device=get_current_device()):
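Note on the new helper: the diff replaces the local run_fwd_bwd defined in test_optim.py with a shared one imported from tests.components_to_test, whose definition is not part of this diff. The sketch below is a hypothetical reconstruction inferred only from the call sites above (run_fwd_bwd(model, data, label, criterion, use_init_ctx=...) returning a loss); the real helper in the repository may differ, and the use_init_ctx branch that drives backward through the ZeroDDP-wrapped model is an assumption.

def run_fwd_bwd(model, data, label, criterion, use_init_ctx: bool = False):
    """Run one forward + backward pass and return the loss (sketch, not the repo's actual helper)."""
    if criterion is not None:
        output = model(data)
        output = output.float()
        loss = criterion(output, label)
    else:
        # some test components compute the loss inside the model's forward
        loss = model(data, label).float()

    if use_init_ctx:
        # assumed path for the ZeroDDP-wrapped model: backward is driven by the wrapper
        model.backward(loss)
    else:
        # plain PyTorch model used as the reference
        loss.backward()
    return loss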