from typing import List, Any from functools import partial def parameterize(argument: str, values: List[Any]): """ This function is to simulate the same behavior as pytest.mark.parameterize. As we want to avoid the number of distributed network initialization, we need to have this extra decorator on the function launched by torch.multiprocessing. If a function is wrapped with this wrapper, non-paramterized arguments must be keyword arguments, positioanl arguments are not allowed. Example 1: @parameterize('person', ['xavier', 'davis']) def say_something(person, msg): print(f'{person}: {msg}') say_something(msg='hello') This will generate output: > xavier: hello > davis: hello Exampel 2: @parameterize('person', ['xavier', 'davis']) @parameterize('msg', ['hello', 'bye', 'stop']) def say_something(person, msg): print(f'{person}: {msg}') say_something() This will generate output: > xavier: hello > xavier: bye > xavier: stop > davis: hello > davis: bye > davis: stop """ def _wrapper(func): def _execute_function_by_param(**kwargs): for val in values: arg_map = {argument: val} partial_func = partial(func, **arg_map) partial_func(**kwargs) return _execute_function_by_param return _wrapper