mirror of https://github.com/hpcaitech/ColossalAI
[test] added a decorator for address already in use error with backward compatibility (#760)
* [test] added a decorator for address already in use error with backward compatibility * [test] added a decorator for address already in use error with backward compatibilitypull/764/head
parent
10ef8afdd2
commit
4ea49cb536
|
@ -1,7 +1,7 @@
|
|||
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
|
||||
from .utils import parameterize, rerun_on_exception
|
||||
from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use
|
||||
|
||||
__all__ = [
|
||||
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
|
||||
'rerun_on_exception'
|
||||
'rerun_on_exception', 'rerun_if_address_is_in_use'
|
||||
]
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import re
|
||||
import torch
|
||||
from typing import Callable, List, Any
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from packaging import version
|
||||
|
||||
|
||||
def parameterize(argument: str, values: List[Any]) -> Callable:
|
||||
|
@ -144,3 +146,29 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
|
|||
return _run_until_success
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def rerun_if_address_is_in_use():
|
||||
"""
|
||||
This function reruns a wrapped function if "address already in use" occurs
|
||||
in testing spawned with torch.multiprocessing
|
||||
|
||||
Usage::
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_something():
|
||||
...
|
||||
|
||||
"""
|
||||
# check version
|
||||
torch_version = version.parse(torch.__version__)
|
||||
assert torch_version.major == 1
|
||||
|
||||
# only torch >= 1.8 has ProcessRaisedException
|
||||
if torch_version.minor >= 8:
|
||||
exception = torch.multiprocessing.ProcessRaisedException
|
||||
else:
|
||||
exception = Exception
|
||||
|
||||
func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
|
||||
return func_wrapper
|
||||
|
|
Loading…
Reference in New Issue