fix(train.py): clear memory pool before optim step

pull/436/head
huangting4201 2023-11-15 14:40:06 +08:00
parent 3c07423151
commit a1fd877828
1 changed file with 7 additions and 4 deletions

View File

@@ -220,7 +220,7 @@ def main(args):
     # start iterating the train data and begin training
     for batch_count in range(train_state.batch_count, total_steps):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
-        # torch.cuda.memory._record_memory_history()
+        torch.cuda.memory._record_memory_history()
         start_time = time.time()
         timer("one-batch").start()
@@ -262,6 +262,9 @@ def main(args):
             )
         timer("fwd-bwd").stop()

+        if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool:
+            gpc.fstp_handler.clear_memory_pool()
+
         # update parameters, and returns (success_update, grad_norm)
         trainer_result = trainer.step()
         assert trainer_result is not None
@@ -324,9 +327,7 @@ def main(args):
                 if batch_count % 2 == 0:
                     prof.step()

-        if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool:
-            gpc.fstp_handler.clear_memory_pool()
-        # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+        torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()

     ckpt_manager.wait_async_upload_finish()
@@ -353,3 +354,5 @@ if __name__ == "__main__":
         mm.monitor_exception(
             alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
         )
+
+    torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")