mirror of https://github.com/InternLM/InternLM
fix(train.py): clear memory pool before optim step
parent 3c07423151
commit a1fd877828
train.py (11 changed lines)
--- a/train.py
+++ b/train.py
@@ -220,7 +220,7 @@ def main(args):
         # start iterating the train data and begin training
         for batch_count in range(train_state.batch_count, total_steps):
             empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
-            # torch.cuda.memory._record_memory_history()
+            torch.cuda.memory._record_memory_history()
             start_time = time.time()
             timer("one-batch").start()
 
@@ -262,6 +262,9 @@ def main(args):
             )
             timer("fwd-bwd").stop()
 
+            if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool:
+                gpc.fstp_handler.clear_memory_pool()
+
             # update parameters, and returns (success_update, grad_norm)
             trainer_result = trainer.step()
             assert trainer_result is not None
@@ -324,9 +327,7 @@ def main(args):
             if batch_count % 2 == 0:
                 prof.step()
 
-            if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool:
-                gpc.fstp_handler.clear_memory_pool()
-        # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+        torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()
 
     ckpt_manager.wait_async_upload_finish()
@@ -353,3 +354,5 @@ if __name__ == "__main__":
         mm.monitor_exception(
             alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
         )
+
+        torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
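The substance of the fix is where the FSTP memory pool is released: previously it was cleared at the end of the iteration, after prof.step(); now it is cleared right after the forward/backward pass and before trainer.step(), presumably so the pooled buffers are already free by the time the optimizer allocates its temporaries. A minimal sketch of the per-iteration ordering after this commit, with data loading and the actual forward/backward code elided (the fwd_bwd() placeholder below is illustrative, not a function in train.py):

# Sketch of the per-iteration ordering established by this commit.
# fwd_bwd() is a stand-in for the real forward/backward code in train.py.
for batch_count in range(train_state.batch_count, total_steps):
    empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
    torch.cuda.memory._record_memory_history()
    timer("one-batch").start()

    timer("fwd-bwd").start()
    fwd_bwd()  # forward + backward; gradients are accumulated here
    timer("fwd-bwd").stop()

    # Release the FSTP memory pool *before* the optimizer step, not after it.
    if gpc.fstp_handler is not None and gpc.fstp_handler.enable_memory_pool:
        gpc.fstp_handler.clear_memory_pool()

    # update parameters, and returns (success_update, grad_norm)
    trainer_result = trainer.step()
    assert trainer_result is not None

    timer("one-batch").stop()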
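The commit also switches on the CUDA memory-history calls that were previously commented out. torch.cuda.memory._record_memory_history() and torch.cuda.memory._dump_snapshot() are private (underscore-prefixed) PyTorch debugging hooks: the first records allocator events together with their Python stack traces, the second writes a pickled snapshot that can be inspected with the memory_viz viewer (https://pytorch.org/memory_viz); torch.cuda.reset_peak_memory_stats() then resets the peak-memory counters. A self-contained sketch of that workflow outside train.py (the matmul is only a placeholder workload, and the exact _record_memory_history signature varies slightly across PyTorch releases):

import torch

# Start recording CUDA allocator events and their stack traces (private API).
torch.cuda.memory._record_memory_history()

x = torch.randn(4096, 4096, device="cuda")  # placeholder workload
y = x @ x

# Dump a pickled snapshot for inspection with https://pytorch.org/memory_viz,
# one file per process, mirroring the per-rank naming used in train.py.
torch.cuda.memory._dump_snapshot("my_snapshot_0.pickle")

# Stop recording; passing enabled=None disables history collection.
torch.cuda.memory._record_memory_history(enabled=None)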