mirror of https://github.com/InternLM/InternLM
fix(train.py): clear memory pool before optim step
parent
3c07423151
commit
a1fd877828
11
train.py
11
train.py
|
@ -220,7 +220,7 @@ def main(args):
|
||||||
# start iterating the train data and begin training
|
# start iterating the train data and begin training
|
||||||
for batch_count in range(train_state.batch_count, total_steps):
|
for batch_count in range(train_state.batch_count, total_steps):
|
||||||
empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
|
empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
|
||||||
# torch.cuda.memory._record_memory_history()
|
torch.cuda.memory._record_memory_history()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
timer("one-batch").start()
|
timer("one-batch").start()
|
||||||
|
|
||||||
|
@ -262,6 +262,9 @@ def main(args):
|
||||||
)
|
)
|
||||||
timer("fwd-bwd").stop()
|
timer("fwd-bwd").stop()
|
||||||
|
|
||||||
|
if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool:
|
||||||
|
gpc.fstp_handler.clear_memory_pool()
|
||||||
|
|
||||||
# update parameters, and returns (success_update, grad_norm)
|
# update parameters, and returns (success_update, grad_norm)
|
||||||
trainer_result = trainer.step()
|
trainer_result = trainer.step()
|
||||||
assert trainer_result is not None
|
assert trainer_result is not None
|
||||||
|
@ -324,9 +327,7 @@ def main(args):
|
||||||
if batch_count % 2 == 0:
|
if batch_count % 2 == 0:
|
||||||
prof.step()
|
prof.step()
|
||||||
|
|
||||||
if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool:
|
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
|
||||||
gpc.fstp_handler.clear_memory_pool()
|
|
||||||
# torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
ckpt_manager.wait_async_upload_finish()
|
ckpt_manager.wait_async_upload_finish()
|
||||||
|
@ -353,3 +354,5 @@ if __name__ == "__main__":
|
||||||
mm.monitor_exception(
|
mm.monitor_exception(
|
||||||
alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
|
alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
|
||||||
|
|
Loading…
Reference in New Issue