[test] added a decorator for address already in use error with backward compatibility (#760)

* [test] added a decorator for address already in use error with backward compatibility

* [test] added a decorator for address already in use error with backward compatibility
pull/764/head
Frank Lee 2022-04-14 16:48:44 +08:00 committed by GitHub
parent 10ef8afdd2
commit 4ea49cb536
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 2 deletions

View File

@ -1,7 +1,7 @@
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
from .utils import parameterize, rerun_on_exception
from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use
__all__ = [
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
'rerun_on_exception'
'rerun_on_exception', 'rerun_if_address_is_in_use'
]

View File

@ -1,7 +1,9 @@
import re
import torch
from typing import Callable, List, Any
from functools import partial
from inspect import signature
from packaging import version
def parameterize(argument: str, values: List[Any]) -> Callable:
@ -144,3 +146,29 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
return _run_until_success
return _wrapper
def rerun_if_address_is_in_use():
"""
This function reruns a wrapped function if "address already in use" occurs
in testing spawned with torch.multiprocessing
Usage::
@rerun_if_address_is_in_use()
def test_something():
...
"""
# check version
torch_version = version.parse(torch.__version__)
assert torch_version.major == 1
# only torch >= 1.8 has ProcessRaisedException
if torch_version.minor >= 8:
exception = torch.multiprocessing.ProcessRaisedException
else:
exception = Exception
func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
return func_wrapper