You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/testing/utils.py

53 lines
1.4 KiB

from functools import partial, wraps
from typing import Any, Callable, List
def parameterize(argument: str, values: List[Any]) -> Callable:
    """Run the decorated function once for every value in ``values``.

    This simulates the behavior of ``pytest.mark.parametrize`` inside a single
    process. Initializing a distributed network is expensive, so functions
    launched by ``torch.multiprocessing`` can use this decorator to sweep
    several configurations without re-initializing the network for each one.

    On each iteration the decorated function is called with
    ``argument=value`` plus whatever keyword arguments the caller supplied.
    Non-parameterized arguments must therefore be passed as keyword
    arguments. The wrapper accepts no positional arguments; passing any
    raises ``TypeError``.

    Decorators can be stacked. The outermost decorator forms the outer loop.

    Args:
        argument: Name of the keyword argument to parameterize.
        values: Values passed for ``argument``, one call per value, in order.

    Returns:
        A decorator. The function it produces always returns ``None``; the
        return values of the individual calls are discarded.

    Note:
        If the caller also passes ``argument`` explicitly as a keyword, that
        value wins (``functools.partial`` lets call-time keywords override
        bound ones). It is then used on every iteration.

    Example 1::

        @parameterize('person', ['xavier', 'davis'])
        def say_something(person, msg):
            print(f'{person}: {msg}')

        say_something(msg='hello')

    This will generate output::

        > xavier: hello
        > davis: hello

    Example 2::

        @parameterize('person', ['xavier', 'davis'])
        @parameterize('msg', ['hello', 'bye', 'stop'])
        def say_something(person, msg):
            print(f'{person}: {msg}')

        say_something()

    This will generate output::

        > xavier: hello
        > xavier: bye
        > xavier: stop
        > davis: hello
        > davis: bye
        > davis: stop
    """

    def _wrapper(func: Callable) -> Callable:
        # Copy func's identity (name, qualname, module, docstring) onto the
        # wrapper for readable logs and tracebacks. updated=() avoids copying
        # func.__dict__, so attributes set on func are not transferred, just as
        # before.
        @wraps(func, updated=())
        def _execute_function_by_param(**kwargs: Any) -> None:
            for val in values:
                # Bind the current value to `argument`; forward caller kwargs as-is.
                partial_func = partial(func, **{argument: val})
                partial_func(**kwargs)

        # wraps() also sets __wrapped__. inspect.signature() follows that
        # attribute, and so does pytest's fixture lookup. Keeping it would
        # advertise func's parameters instead of the wrapper's real
        # **kwargs-only signature, and break tests that pytest collects
        # directly. Remove it.
        del _execute_function_by_param.__wrapped__
        return _execute_function_by_param

    return _wrapper