mirror of https://github.com/hpcaitech/ColossalAI
[test] added a decorator for address already in use error with backward compatibility (#760)
* [test] added a decorator for address already in use error with backward compatibility * [test] added a decorator for address already in use error with backward compatibilitypull/764/head
parent
10ef8afdd2
commit
4ea49cb536
|
@ -1,7 +1,7 @@
|
||||||
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
|
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
|
||||||
from .utils import parameterize, rerun_on_exception
|
from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
|
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
|
||||||
'rerun_on_exception'
|
'rerun_on_exception', 'rerun_if_address_is_in_use'
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
import re
|
import re
|
||||||
|
import torch
|
||||||
from typing import Callable, List, Any
|
from typing import Callable, List, Any
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
|
||||||
def parameterize(argument: str, values: List[Any]) -> Callable:
|
def parameterize(argument: str, values: List[Any]) -> Callable:
|
||||||
|
@ -144,3 +146,29 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
|
||||||
return _run_until_success
|
return _run_until_success
|
||||||
|
|
||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def rerun_if_address_is_in_use():
|
||||||
|
"""
|
||||||
|
This function reruns a wrapped function if "address already in use" occurs
|
||||||
|
in testing spawned with torch.multiprocessing
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_something():
|
||||||
|
...
|
||||||
|
|
||||||
|
"""
|
||||||
|
# check version
|
||||||
|
torch_version = version.parse(torch.__version__)
|
||||||
|
assert torch_version.major == 1
|
||||||
|
|
||||||
|
# only torch >= 1.8 has ProcessRaisedException
|
||||||
|
if torch_version.minor >= 8:
|
||||||
|
exception = torch.multiprocessing.ProcessRaisedException
|
||||||
|
else:
|
||||||
|
exception = Exception
|
||||||
|
|
||||||
|
func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
|
||||||
|
return func_wrapper
|
||||||
|
|
Loading…
Reference in New Issue