# Testing utilities: post-patch content of colossalai/testing/utils.py.
# Provenance: "[test] added rerun on exception for testing (#475)",
# commit 83a847d0586b6002fdd51e1d4be9c6f39fb45c2a (Frank Lee, 2022-03-21).
# The same patch also exports `rerun_on_exception` (next to `parameterize`)
# from colossalai/testing/__init__.py and lists it in that module's `__all__`.
import re
from functools import partial, wraps
from typing import Any, Callable, List, Optional, Type


def parameterize(argument: str, values: List[Any]) -> Callable:
    """
    Simulate the behavior of pytest.mark.parameterize inside a single call.

    The decorated function is executed once per entry of ``values``, with that
    entry bound to the keyword argument named ``argument``. This avoids
    repeating the distributed network initialization for every parameter value.

    If a function is wrapped with this decorator, the non-parameterized
    arguments must be passed as keyword arguments; positional arguments are
    not allowed. The wrapped function's return values are discarded and the
    wrapper returns None.

    Usage::

        # Example 1:
        @parameterize('person', ['xavier', 'davis'])
        def say_something(person, msg):
            print(f'{person}: {msg}')

        say_something(msg='hello')

        # This will generate output:
        # > xavier: hello
        # > davis: hello

        # Example 2:
        @parameterize('person', ['xavier', 'davis'])
        @parameterize('msg', ['hello', 'bye', 'stop'])
        def say_something(person, msg):
            print(f'{person}: {msg}')

        say_something()

        # This will generate output:
        # > xavier: hello
        # > xavier: bye
        # > xavier: stop
        # > davis: hello
        # > davis: bye
        # > davis: stop

    Args:
        argument (str): the name of the argument to parameterize
        values (List[Any]): a list of values to iterate for this argument

    Returns:
        Callable: a decorator which turns a function into its parameterized version
    """

    def _wrapper(func):

        # NOTE: functools.wraps is intentionally not applied here. Exposing the
        # wrapped signature would advertise `argument` as a required parameter
        # although this wrapper is the one supplying it.
        def _execute_function_by_param(**kwargs):
            for val in values:
                # bind the current value first; explicit kwargs from the caller
                # are applied on top by functools.partial
                partial_func = partial(func, **{argument: val})
                partial_func(**kwargs)

        return _execute_function_by_param

    return _wrapper


def rerun_on_exception(exception_type: Type[Exception] = Exception,
                       pattern: Optional[str] = None,
                       max_try: Optional[int] = 5) -> Callable:
    """
    A decorator on a function to re-run it when an exception occurs.

    The function is called until it succeeds; its return value is then handed
    back to the caller. A failed call is retried only if the exception is an
    instance of ``exception_type``, its message matches ``pattern`` (when a
    pattern is given) and the attempt budget is not used up. Otherwise the
    exception propagates to the caller.

    Usage::

        # rerun for all kinds of exception
        @rerun_on_exception()
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun for RuntimeError only
        @rerun_on_exception(exception_type=RuntimeError)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # run at most 10 times if RuntimeError occurs
        @rerun_on_exception(exception_type=RuntimeError, max_try=10)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun for infinite times if RuntimeError occurs
        @rerun_on_exception(exception_type=RuntimeError, max_try=None)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun only if the exception message is matched with the pattern,
        # at most 5 times in total (the default max_try)
        @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

    Args:
        exception_type (Type[Exception], Optional): The type of exception to detect for rerun.
        pattern (str, Optional): The pattern to match the exception message.
            It is applied with ``re.match``, i.e. anchored at the start of the message.
            If the pattern is None, every exception of ``exception_type`` triggers a rerun.
        max_try (int, Optional): Maximum number of times the function is executed.
            The default value is 5. If max_try is None, the function is rerun
            forever as long as the exception keeps occurring.

    Returns:
        Callable: a decorator which adds the rerun behavior to a function

    Raises:
        AssertionError: at call time, if ``max_try`` is neither None nor an int.
        Exception: the exception raised by the wrapped function, if it is not
            eligible for a rerun or the last allowed attempt has failed.
    """

    def _wrapper(func):

        # keep the wrapped function's name, docstring and signature metadata
        @wraps(func)
        def _run_until_success(*args, **kwargs):
            assert max_try is None or isinstance(max_try, int), \
                f'Expected max_try to be None or int, but got {type(max_try)}'

            try_count = 0
            while max_try is None or try_count < max_try:
                try_count += 1
                try:
                    # stop at the first successful call and return its result
                    return func(*args, **kwargs)
                except exception_type as e:
                    # when pattern is not specified, every caught exception is retryable
                    # when pattern is specified, only a matching message is retryable
                    is_retryable = pattern is None or re.match(pattern, str(e)) is not None
                    has_attempts_left = max_try is None or try_count < max_try
                    if not (is_retryable and has_attempts_left):
                        # surface the failure instead of silently swallowing it
                        raise

            # only reachable when max_try <= 0: the function is never executed
            return None

        return _run_until_success

    return _wrapper