mirror of https://github.com/hpcaitech/ColossalAI
[test] fix gemini checkpoint and gpt test (#4620)
parent e71d245293
commit bd18678478
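This commit makes two test fixes. In the checkpoint test, GeminiPlugin is no longer constructed with placement_policy='cuda' (only precision and initial_scale remain), which tracks a rework of the plugin's placement options. In the GPT-2 shardformer test, the "will hang in CI" skip marker is removed from the parameterized forward/backward test and attached to the 3D test entry point instead.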
@@ -32,7 +32,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
     elif plugin_type == 'zero':
         plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
     elif plugin_type == 'gemini':
-        plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32)
+        plugin = GeminiPlugin(precision="fp16", initial_scale=32)
     else:
         raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")
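For context on the changed line: the plugin built here is handed to ColossalAI's Booster. Below is a minimal, hedged sketch of that flow using the updated constructor call; the model, optimizer, and launch details are illustrative assumptions, not part of the diff (run under torchrun).

import colossalai
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Gemini needs an initialized distributed backend; under torchrun this
# one-liner sets it up (a config dict was still required at this point).
colossalai.launch_from_torch(config={})

# The plugin exactly as the patched test builds it: no explicit
# placement_policy, fp16 mixed precision, small initial loss scale.
plugin = GeminiPlugin(precision="fp16", initial_scale=32)
booster = Booster(plugin=plugin)

# Illustrative model/optimizer; the real test loads a pretrained model.
model = nn.Linear(8, 8)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# boost() wraps both for Gemini-managed (chunk-based, offload-capable)
# training and returns them alongside unused criterion/dataloader slots.
model, optimizer, *_ = booster.boost(model, optimizer)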
@@ -102,7 +102,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     torch.cuda.empty_cache()


-@pytest.mark.skip(reason="This test will hang in CI")
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,
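The decorator that remains, @parameterize, is ColossalAI's test helper: it calls the decorated function once per config dict, injecting the dict as the named argument. A small sketch of that mechanism, where the second config and the test body are illustrative assumptions:

from colossalai.testing import parameterize

@parameterize('test_config', [{
    'tp_size': 2,
    'pp_size': 2,
}, {
    'tp_size': 1,
    'pp_size': 4,
}])
def run_gpt2_test(test_config):
    # Each config dict is injected as `test_config`, so the body
    # runs once per parallel layout.
    print(test_config['tp_size'], test_config['pp_size'])

# Calling the wrapper with no arguments iterates over all configs.
run_gpt2_test()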
@@ -220,7 +219,7 @@ def check_gpt2_3d(rank, world_size, port):
     run_gpt2_3d_test()


+@pytest.mark.skip(reason="This test will hang in CI")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
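The skip marker now sits on the 3D entry point, which in ColossalAI's test suite is a spawn-based launcher. A hedged sketch of that pattern, with the worker body assumed from the hunk's context lines rather than shown in the diff:

import pytest
import colossalai
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

def check_gpt2_3d(rank, world_size, port):
    # Each worker initializes the distributed backend, then runs the
    # parameterized 3D test body (run_gpt2_3d_test in the real file).
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')

@pytest.mark.skip(reason="This test will hang in CI")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2_3d():
    # spawn() forks the worker across processes; the skip marker keeps
    # pytest from ever entering this launcher on CI runners.
    spawn(check_gpt2_3d, 4)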