From a1fd8778288b0ab76d20fa39290c8fe62cd5e654 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Wed, 15 Nov 2023 14:40:06 +0800 Subject: [PATCH] fix(train.py): clear memory pool before optim step --- train.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 5ea91e8..789094a 100644 --- a/train.py +++ b/train.py @@ -220,7 +220,7 @@ def main(args): # start iterating the train data and begin training for batch_count in range(train_state.batch_count, total_steps): empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) - # torch.cuda.memory._record_memory_history() + torch.cuda.memory._record_memory_history() start_time = time.time() timer("one-batch").start() @@ -262,6 +262,9 @@ def main(args): ) timer("fwd-bwd").stop() + if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool: + gpc.fstp_handler.clear_memory_pool() + # update parameters, and returns (success_update, grad_norm) trainer_result = trainer.step() assert trainer_result is not None @@ -324,9 +327,7 @@ def main(args): if batch_count % 2 == 0: prof.step() - if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool: - gpc.fstp_handler.clear_memory_pool() - # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") torch.cuda.reset_peak_memory_stats() ckpt_manager.wait_async_upload_finish() @@ -353,3 +354,5 @@ if __name__ == "__main__": mm.monitor_exception( alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() ) + + torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")