[ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)

* fix ci

fix

* fix test

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

* fix

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
pull/5278/head^2
flybird11111 2024-01-17 13:38:55 +08:00 committed by GitHub
parent d69cd2eb89
commit 2a0558d8ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 13 additions and 11 deletions

View File

@@ -104,30 +104,32 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     # Check whether the loaded model & optimizer works smoothly.
     model.train()
     new_model.train()
+    data_for_shard = data_gen_fn()
+    data_for_origin = data_gen_fn()
     if booster.plugin.stage_manager is not None:
         booster.execute_pipeline(
-            _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False
+            _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False
         )
         booster.execute_pipeline(
-            _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False
+            _preprocess_data(data_for_origin),
+            new_model,
+            _criterion,
+            new_optimizer,
+            return_loss=True,
+            return_outputs=False,
         )
     else:
-        old_model_loss = criterion(model(**_preprocess_data(data)))
+        old_model_loss = criterion(model(**_preprocess_data(data_for_shard)))
         optimizer.backward(old_model_loss)
-        new_model_loss = criterion(new_model(**_preprocess_data(data)))
+        new_model_loss = criterion(new_model(**_preprocess_data(data_for_origin)))
         new_optimizer.backward(new_model_loss)
         optimizer.step()
         new_optimizer.step()

     # Check updated weights.
-    stage_manager = booster.plugin.stage_manager
-
-    if stage_manager is None or stage_manager.is_first_stage():
-        assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
-        assert_close_loose(
-            model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3
-        )
+    for p1, p2 in zip(model.unwrap().parameters(), new_model.unwrap().parameters()):
+        assert_close_loose(p1, p2, atol=5e-3, rtol=5e-3)

     dist.barrier()
     Randomizer.reset_index()