mirror of https://github.com/InternLM/InternLM
Feat(PythonGC): Do garbage collection manually (#326)
* feat:add gc control * feat:add gc control * feat:add gc control * feat:add gc * re-lintpull/338/head^2
parent
3b0eff0c8a
commit
f5337f6e02
|
@ -2,6 +2,7 @@
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import gc
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
@ -446,6 +447,8 @@ def initialize_distributed_env(
|
||||||
master_port (str): The master port for distributed training. 8888 by default.
|
master_port (str): The master port for distributed training. 8888 by default.
|
||||||
seed (int, optional): Specified random seed for every process. 1024 by default.
|
seed (int, optional): Specified random seed for every process. 1024 by default.
|
||||||
"""
|
"""
|
||||||
|
# close automatic garbage collection
|
||||||
|
gc.disable()
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
import gc
|
||||||
import math
|
import math
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
|
@ -41,6 +42,8 @@ def empty_cache_and_diag(batch_count, interval=50):
|
||||||
bench_net()
|
bench_net()
|
||||||
# do empty_cache after the bench
|
# do empty_cache after the bench
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
# do garbage collection
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
def benchmark_forward(
|
def benchmark_forward(
|
||||||
|
|
Loading…
Reference in New Issue