diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 865262cae..a42b550cd 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -104,30 +104,32 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf # Check whether the loaded model & optimizer works smoothly. model.train() new_model.train() + data_for_shard = data_gen_fn() + data_for_origin = data_gen_fn() if booster.plugin.stage_manager is not None: booster.execute_pipeline( - _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False ) booster.execute_pipeline( - _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False + _preprocess_data(data_for_origin), + new_model, + _criterion, + new_optimizer, + return_loss=True, + return_outputs=False, ) else: - old_model_loss = criterion(model(**_preprocess_data(data))) + old_model_loss = criterion(model(**_preprocess_data(data_for_shard))) optimizer.backward(old_model_loss) - new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_model_loss = criterion(new_model(**_preprocess_data(data_for_origin))) new_optimizer.backward(new_model_loss) optimizer.step() new_optimizer.step() # Check updated weights. - stage_manager = booster.plugin.stage_manager - - if stage_manager is None or stage_manager.is_first_stage(): - assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) - assert_close_loose( - model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3 - ) + for p1, p2 in zip(model.unwrap().parameters(), new_model.unwrap().parameters()): + assert_close_loose(p1, p2, atol=5e-3, rtol=5e-3) dist.barrier() Randomizer.reset_index()