mirror of https://github.com/hpcaitech/ColossalAI
[ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)
* fix ci fix
* fix test
* revert: revert p2p
* feat: add enable_metadata_cache option
* revert: enable t5 tests
* fix
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
pull/5278/head^2
parent
d69cd2eb89
commit
2a0558d8ec
|
@@ -104,30 +104,32 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
|
|||
# Check whether the loaded model & optimizer works smoothly.
|
||||
model.train()
|
||||
new_model.train()
|
||||
data_for_shard = data_gen_fn()
|
||||
data_for_origin = data_gen_fn()
|
||||
if booster.plugin.stage_manager is not None:
|
||||
booster.execute_pipeline(
|
||||
_preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False
|
||||
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False
|
||||
)
|
||||
booster.execute_pipeline(
|
||||
_preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False
|
||||
_preprocess_data(data_for_origin),
|
||||
new_model,
|
||||
_criterion,
|
||||
new_optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=False,
|
||||
)
|
||||
else:
|
||||
old_model_loss = criterion(model(**_preprocess_data(data)))
|
||||
old_model_loss = criterion(model(**_preprocess_data(data_for_shard)))
|
||||
optimizer.backward(old_model_loss)
|
||||
new_model_loss = criterion(new_model(**_preprocess_data(data)))
|
||||
new_model_loss = criterion(new_model(**_preprocess_data(data_for_origin)))
|
||||
new_optimizer.backward(new_model_loss)
|
||||
|
||||
optimizer.step()
|
||||
new_optimizer.step()
|
||||
|
||||
# Check updated weights.
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
|
||||
assert_close_loose(
|
||||
model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3
|
||||
)
|
||||
for p1, p2 in zip(model.unwrap().parameters(), new_model.unwrap().parameters()):
|
||||
assert_close_loose(p1, p2, atol=5e-3, rtol=5e-3)
|
||||
|
||||
dist.barrier()
|
||||
Randomizer.reset_index()
|
||||
|
|
Loading…
Reference in New Issue