diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py new file mode 100644 index 000000000..519b35e32 --- /dev/null +++ b/colossalai/testing/__init__.py @@ -0,0 +1,4 @@ +from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group +from .utils import parameterize + +__all__ = ['assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize'] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py new file mode 100644 index 000000000..052e564e9 --- /dev/null +++ b/colossalai/testing/comparison.py @@ -0,0 +1,29 @@ +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + + +def assert_equal(a: Tensor, b: Tensor): + assert torch.all(a == b), f'expected a and b to be equal but they are not, {a} vs {b}' + +def assert_not_equal(a: Tensor, b: Tensor): + assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}' + +def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8): + assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}' + +def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-2, atol: float = 1e-3): + assert_close(a, b, rtol, atol) + +def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): + # all gather tensors from different ranks + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group=process_group) + + # check if they are equal one by one + for i in range(world_size - 1): + a = tensor_list[i] + b = tensor_list[i+1] + assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}' diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py new file mode 100644 index 000000000..686281f4c --- /dev/null +++ b/colossalai/testing/utils.py @@ -0,0 +1,52 @@ +from typing import List, Any +from functools import partial + + +def parameterize(argument: str, values: List[Any]): + """ + This function is to simulate the same behavior as pytest.mark.parameterize. As + we want to avoid the number of distributed network initialization, we need to have + this extra decorator on the function launched by torch.multiprocessing. + + If a function is wrapped with this wrapper, non-paramterized arguments must be keyword arguments, + positioanl arguments are not allowed. + + Example 1: + + @parameterize('person', ['xavier', 'davis']) + def say_something(person, msg): + print(f'{person}: {msg}') + + say_something(msg='hello') + + This will generate output: + > xavier: hello + > davis: hello + + + Exampel 2: + + @parameterize('person', ['xavier', 'davis']) + @parameterize('msg', ['hello', 'bye', 'stop']) + def say_something(person, msg): + print(f'{person}: {msg}') + + say_something() + + This will generate output: + > xavier: hello + > xavier: bye + > xavier: stop + > davis: hello + > davis: bye + > davis: stop + """ + + def _wrapper(func): + def _execute_function_by_param(**kwargs): + for val in values: + arg_map = {argument: val} + partial_func = partial(func, **arg_map) + partial_func(**kwargs) + return _execute_function_by_param + return _wrapper