[test] added a decorator for address already in use error with backward compatibility (#760)

* [test] added a decorator for address already in use error with backward compatibility

* [test] added a decorator for address already in use error with backward compatibility
pull/764/head
Frank Lee 2022-04-14 16:48:44 +08:00 committed by GitHub
parent 10ef8afdd2
commit 4ea49cb536
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 2 deletions

View File

@ -1,7 +1,7 @@
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
from .utils import parameterize, rerun_on_exception from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use
__all__ = [ __all__ = [
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
'rerun_on_exception' 'rerun_on_exception', 'rerun_if_address_is_in_use'
] ]

View File

@ -1,7 +1,9 @@
import re import re
import torch
from typing import Callable, List, Any from typing import Callable, List, Any
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from packaging import version
def parameterize(argument: str, values: List[Any]) -> Callable: def parameterize(argument: str, values: List[Any]) -> Callable:
@ -144,3 +146,29 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
return _run_until_success return _run_until_success
return _wrapper return _wrapper
def rerun_if_address_is_in_use():
"""
This function reruns a wrapped function if "address already in use" occurs
in testing spawned with torch.multiprocessing
Usage::
@rerun_if_address_is_in_use()
def test_something():
...
"""
# check version
torch_version = version.parse(torch.__version__)
assert torch_version.major == 1
# only torch >= 1.8 has ProcessRaisedException
if torch_version.minor >= 8:
exception = torch.multiprocessing.ProcessRaisedException
else:
exception = Exception
func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
return func_wrapper