|
|
@ -1,8 +1,13 @@ |
|
|
|
|
|
|
|
import gc |
|
|
|
|
|
|
|
import random |
|
|
|
import re |
|
|
|
import re |
|
|
|
import torch |
|
|
|
import socket |
|
|
|
from typing import Callable, List, Any |
|
|
|
|
|
|
|
from functools import partial |
|
|
|
from functools import partial |
|
|
|
from inspect import signature |
|
|
|
from inspect import signature |
|
|
|
|
|
|
|
from typing import Any, Callable, List |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
import torch.multiprocessing as mp |
|
|
|
from packaging import version |
|
|
|
from packaging import version |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -43,7 +48,7 @@ def parameterize(argument: str, values: List[Any]) -> Callable: |
|
|
|
# > davis: hello |
|
|
|
# > davis: hello |
|
|
|
# > davis: bye |
|
|
|
# > davis: bye |
|
|
|
# > davis: stop |
|
|
|
# > davis: stop |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
Args: |
|
|
|
argument (str): the name of the argument to parameterize |
|
|
|
argument (str): the name of the argument to parameterize |
|
|
|
values (List[Any]): a list of values to iterate for this argument |
|
|
|
values (List[Any]): a list of values to iterate for this argument |
|
|
@ -85,13 +90,13 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non |
|
|
|
def test_method(): |
|
|
|
def test_method(): |
|
|
|
print('hey') |
|
|
|
print('hey') |
|
|
|
raise RuntimeError('Address already in use') |
|
|
|
raise RuntimeError('Address already in use') |
|
|
|
|
|
|
|
|
|
|
|
# rerun for infinite times if Runtime error occurs |
|
|
|
# rerun for infinite times if Runtime error occurs |
|
|
|
@rerun_on_exception(exception_type=RuntimeError, max_try=None) |
|
|
|
@rerun_on_exception(exception_type=RuntimeError, max_try=None) |
|
|
|
def test_method(): |
|
|
|
def test_method(): |
|
|
|
print('hey') |
|
|
|
print('hey') |
|
|
|
raise RuntimeError('Address already in use') |
|
|
|
raise RuntimeError('Address already in use') |
|
|
|
|
|
|
|
|
|
|
|
# rerun only the exception message is matched with pattern |
|
|
|
# rerun only the exception message is matched with pattern |
|
|
|
# for infinite times if Runtime error occurs |
|
|
|
# for infinite times if Runtime error occurs |
|
|
|
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") |
|
|
|
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") |
|
|
@ -101,10 +106,10 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
Args: |
|
|
|
exception_type (Exception, Optional): The type of exception to detect for rerun |
|
|
|
exception_type (Exception, Optional): The type of exception to detect for rerun |
|
|
|
pattern (str, Optional): The pattern to match the exception message. |
|
|
|
pattern (str, Optional): The pattern to match the exception message. |
|
|
|
If the pattern is not None and matches the exception message, |
|
|
|
If the pattern is not None and matches the exception message, |
|
|
|
the exception will be detected for rerun |
|
|
|
the exception will be detected for rerun |
|
|
|
max_try (int, Optional): Maximum reruns for this function. The default value is 5. |
|
|
|
max_try (int, Optional): Maximum reruns for this function. The default value is 5. |
|
|
|
If max_try is None, it will rerun foreven if exception keeps occurings |
|
|
|
If max_try is None, it will rerun foreven if exception keeps occurings |
|
|
|
""" |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int): |
|
|
|
return _execute_by_gpu_num |
|
|
|
return _execute_by_gpu_num |
|
|
|
|
|
|
|
|
|
|
|
return _wrap_func |
|
|
|
return _wrap_func |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def free_port() -> int: |
|
|
|
|
|
|
|
"""Get a free port on localhost. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
|
|
int: A free port on localhost. |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
while True: |
|
|
|
|
|
|
|
port = random.randint(20000, 65000) |
|
|
|
|
|
|
|
try: |
|
|
|
|
|
|
|
with socket.socket() as sock: |
|
|
|
|
|
|
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
|
|
|
|
|
|
|
sock.bind(("localhost", port)) |
|
|
|
|
|
|
|
return port |
|
|
|
|
|
|
|
except OSError: |
|
|
|
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def spawn(func, nprocs=1, **kwargs): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
This function is used to spawn processes for testing. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Usage: |
|
|
|
|
|
|
|
# must contians arguments rank, world_size, port |
|
|
|
|
|
|
|
def do_something(rank, world_size, port): |
|
|
|
|
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spawn(do_something, nprocs=8) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# can also pass other arguments |
|
|
|
|
|
|
|
def do_something(rank, world_size, port, arg1, arg2): |
|
|
|
|
|
|
|
... |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spawn(do_something, nprocs=8, arg1=1, arg2=2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
|
|
func (Callable): The function to be spawned. |
|
|
|
|
|
|
|
nprocs (int, optional): The number of processes to spawn. Defaults to 1. |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
port = free_port() |
|
|
|
|
|
|
|
wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs) |
|
|
|
|
|
|
|
mp.spawn(wrapped_func, nprocs=nprocs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear_cache_before_run(): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
This function is a wrapper to clear CUDA and python cache before executing the function. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Usage: |
|
|
|
|
|
|
|
@clear_cache_before_run() |
|
|
|
|
|
|
|
def test_something(): |
|
|
|
|
|
|
|
... |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _wrap_func(f): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _clear_cache(*args, **kwargs): |
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
torch.cuda.reset_peak_memory_stats() |
|
|
|
|
|
|
|
torch.cuda.reset_max_memory_allocated() |
|
|
|
|
|
|
|
torch.cuda.reset_max_memory_cached() |
|
|
|
|
|
|
|
torch.cuda.synchronize() |
|
|
|
|
|
|
|
gc.collect() |
|
|
|
|
|
|
|
f(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return _clear_cache |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return _wrap_func |
|
|
|