def rerun_if_address_is_in_use():
    """
    Build a decorator that reruns a wrapped function if "address already in use"
    occurs in testing spawned with torch.multiprocessing.

    The exception type to watch for is ``torch.multiprocessing.ProcessRaisedException``
    when the installed torch provides it, and plain ``Exception`` otherwise. The
    decorator itself is produced by ``rerun_on_exception`` with the pattern
    ``".*Address already in use.*"``.

    Usage::

        @rerun_if_address_is_in_use()
        def test_something():
            ...

    Returns:
        The function wrapper returned by ``rerun_on_exception``.
    """
    # ProcessRaisedException only exists in torch >= 1.8. Detect it by feature
    # instead of by parsing the version string: the previous check asserted
    # ``major == 1`` (crashing on torch 2.x, and vanishing under ``python -O``)
    # and compared ``minor >= 8`` without looking at the major version.
    exception = getattr(torch.multiprocessing, "ProcessRaisedException", Exception)

    return rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")