From b72b8445c69f56f6522791585897c595dc89df7d Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 17 Mar 2022 14:40:52 +0800 Subject: [PATCH] optimized context test time consumption (#446) --- colossalai/context/parallel_context.py | 1 + colossalai/testing/comparison.py | 2 +- colossalai/utils/common.py | 1 + tests/test_amp/test_naive_fp16.py | 7 +- tests/test_context/test_2d_init.py | 105 ------------- tests/test_context/test_2p5d_init.py | 128 ---------------- tests/test_context/test_3d_init.py | 120 --------------- tests/test_context/test_hybrid_parallel.py | 162 +++++++++++++++++++++ 8 files changed, 169 insertions(+), 357 deletions(-) delete mode 100644 tests/test_context/test_2d_init.py delete mode 100644 tests/test_context/test_2p5d_init.py delete mode 100644 tests/test_context/test_3d_init.py create mode 100644 tests/test_context/test_hybrid_parallel.py diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index 379497b48..376c582ab 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -449,6 +449,7 @@ class ParallelContext: dist.destroy_process_group(group) # destroy global process group dist.destroy_process_group() + self._groups.clear() def set_device(self, device_ordinal: int = None): """Sets distributed processes to be bound to devices. diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 052e564e9..e98f3c18b 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -13,7 +13,7 @@ def assert_not_equal(a: Tensor, b: Tensor): def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8): assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}' -def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-2, atol: float = 1e-3): +def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): assert_close(a, b, rtol, atol) def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 6427e4c8a..c9fe7d4d5 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -46,6 +46,7 @@ def free_port(): while True: try: sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) port = random.randint(20000, 65000) sock.bind(('localhost', port)) sock.close() diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py index d3de7c851..c6805ad51 100644 --- a/tests/test_amp/test_naive_fp16.py +++ b/tests/test_amp/test_naive_fp16.py @@ -5,6 +5,7 @@ import pytest import torch.multiprocessing as mp from colossalai.amp import convert_to_naive_amp from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.testing import assert_close_loose from colossalai.utils import free_port from functools import partial @@ -48,7 +49,7 @@ def run_naive_amp(): # forward pass amp_output = amp_model(data) torch_output = torch_model(data) - assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}' + assert_close_loose(amp_output, torch_output) # backward amp_optimizer.backward(amp_output.mean()) @@ -56,7 +57,7 @@ def run_naive_amp(): # check grad for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()): - torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3) + 
assert_close_loose(amp_param.grad, torch_param.grad.half()) # step amp_optimizer.step() @@ -64,7 +65,7 @@ def run_naive_amp(): # check updated param for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()): - torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3) + assert_close_loose(amp_param, torch_param.half()) def run_dist(rank, world_size, port): diff --git a/tests/test_context/test_2d_init.py b/tests/test_context/test_2d_init.py deleted file mode 100644 index 117b6e0d6..000000000 --- a/tests/test_context/test_2d_init.py +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial -from pathlib import Path - -import pytest -import torch -import torch.multiprocessing as mp -from colossalai import launch -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import free_port - -CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute() - - -def check_data_parallel_rank(rank): - if rank in [0, 1, 2, 3, 4, 5, 6, 7]: - assert gpc.get_local_rank(ParallelMode.DATA) == 0 - elif rank in [8, 9, 10, 11, 12, 13, 14, 15]: - assert gpc.get_local_rank(ParallelMode.DATA) == 1 - - -def check_pipeline_parallel_rank(rank): - if rank in [0, 1, 2, 3]: - assert gpc.get_local_rank(ParallelMode.PIPELINE) == 0 - elif rank in [4, 5, 6, 7]: - assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1 - elif rank in [8, 9, 10, 11]: - assert gpc.get_local_rank(ParallelMode.PIPELINE) == 0 - elif rank in [12, 13, 14, 15]: - assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1 - - -def check_model_parallel_rank(rank): - for i in range(8): - if rank in [i, i+8]: - assert gpc.get_local_rank(ParallelMode.MODEL) == i - - -def check_tensor_parallel_rank(rank): - if rank in [0, 4, 8, 12]: - assert gpc.get_local_rank(ParallelMode.TENSOR) == 0 - elif rank in [1, 5, 9, 13]: - assert gpc.get_local_rank(ParallelMode.TENSOR) == 1 - elif rank in [2, 6, 10, 14]: - assert gpc.get_local_rank(ParallelMode.TENSOR) == 2 - elif rank in [3, 7, 11, 15]: - assert gpc.get_local_rank(ParallelMode.TENSOR) == 3 - - -def check_2d_parallel_rank(rank): - if rank in [0, 4, 8, 12]: - assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0 - assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 0 - elif rank in [1, 5, 9, 13]: - assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0 - assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1 - elif rank in [2, 6, 10, 14]: - assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 1 - assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 0 - elif rank in [3, 7, 11, 15]: - assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 1 - assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1 - - -def init_2d(rank, world_size, backend, port, host): - dist_args = dict( - config=CONFIG_PATH, - rank=rank, - world_size=world_size, - backend=backend, - port=port, - host=host, - verbose=True - ) - launch(**dist_args) - - check_tensor_parallel_rank(rank) - check_data_parallel_rank(rank) - check_2d_parallel_rank(rank) - check_pipeline_parallel_rank(rank) - check_model_parallel_rank(rank) - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.cpu -def test_2d_init(): - """ - As no computation or communication is done, we can run this test on CPU. 
- """ - world_size = 16 - test_fn = partial(init_2d, - world_size=world_size, - backend='gloo', - port=free_port(), - host='localhost' - ) - mp.spawn(test_fn, nprocs=world_size) - - -if __name__ == '__main__': - test_2d_init() diff --git a/tests/test_context/test_2p5d_init.py b/tests/test_context/test_2p5d_init.py deleted file mode 100644 index ef6789710..000000000 --- a/tests/test_context/test_2p5d_init.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial -from pathlib import Path - -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.utils import free_port - -CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute() - - -def check_data_parallel_rank(rank): - dp_rank = gpc.get_local_rank(ParallelMode.DATA) - - if rank in list(range(16)): - assert dp_rank == 0 - elif rank in list(range(16, 32)): - assert dp_rank == 1 - - -def check_pipeline_parallel_rank(rank): - ppr = gpc.get_local_rank(ParallelMode.PIPELINE) - - if rank in list(range(8)): - assert ppr == 0 - elif rank in list(range(8, 16)): - assert ppr == 1 - elif rank in list(range(16, 24)): - assert ppr == 0 - elif rank in list(range(24, 32)): - assert ppr == 1 - - -def check_model_parallel_rank(rank): - for i in range(16): - if rank in [i, i+16]: - assert gpc.get_local_rank(ParallelMode.MODEL) == i - - -def check_tensor_parallel_rank(rank): - tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - - for i in range(8): - ranks = list(range(i, 32, 8)) - if rank in ranks: - assert tp_rank == i, f'{rank}:{tp_rank}' - - -def check_2p5d_parallel_rank(rank): - rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) - xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ) - - # check for row parallel group - for i in range(2): - ranks = list(range(i, 32, 2)) - if rank in ranks: - assert rp_rank == i - - # check for col parallel group - for i in range(2): - ranks = list(range(i * 2, 32, 4)) - ranks_plus_ones = [val + 1 for val in ranks] - ranks.extend(ranks_plus_ones) - if rank in ranks: - assert cp_rank == i - - # check for depth parallel group - for i in range(2): - ranks = [] - for j in range(i * 4, 32, 8): - ranks.extend([j + k for k in range(4)]) - if rank in ranks: - assert dp_rank == i - - # check for xz parallel group - for i in range(2): - ranks = list(range(i * 2, 32, 8)) - ranks_plus_one = [val + 1 for val in ranks] - ranks.extend(ranks_plus_one) - if rank in ranks: - assert xp_rank == i - - -def init_2halfd(rank, world_size, backend, port, host): - dist_args = dict( - config=CONFIG_PATH, - rank=rank, - world_size=world_size, - backend=backend, - port=port, - host=host, - verbose=True - ) - launch(**dist_args) - check_data_parallel_rank(rank) - check_pipeline_parallel_rank(rank) - check_tensor_parallel_rank(rank) - check_2p5d_parallel_rank(rank) - check_model_parallel_rank(rank) - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.cpu -def test_2halfd_init(): - """ - As no computation or communication is done, we can run this test on CPU. 
- """ - world_size = 32 - test_fn = partial(init_2halfd, - world_size=world_size, - backend='gloo', - port=free_port(), - host='localhost' - ) - mp.spawn(test_fn, nprocs=world_size) - - -if __name__ == '__main__': - test_2halfd_init() diff --git a/tests/test_context/test_3d_init.py b/tests/test_context/test_3d_init.py deleted file mode 100644 index 12f0f1ea5..000000000 --- a/tests/test_context/test_3d_init.py +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial -from pathlib import Path - -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.utils import free_port - -CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute() - - -def check_data_parallel_rank(rank): - dp_rank = gpc.get_local_rank(ParallelMode.DATA) - - if rank in list(range(16)): - assert dp_rank == 0 - elif rank in list(range(16, 32)): - assert dp_rank == 1 - - -def check_pipeline_parallel_rank(rank): - ppr = gpc.get_local_rank(ParallelMode.PIPELINE) - - if rank in list(range(8)): - assert ppr == 0 - elif rank in list(range(8, 16)): - assert ppr == 1 - elif rank in list(range(16, 24)): - assert ppr == 0 - elif rank in list(range(24, 32)): - assert ppr == 1 - - -def check_model_parallel_rank(rank): - for i in range(16): - if rank in [i, i+16]: - assert gpc.get_local_rank(ParallelMode.MODEL) == i - - -def check_tensor_parallel_rank(rank): - tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - - for i in range(8): - ranks = list(range(i, 32, 8)) - if rank in ranks: - assert tp_rank == i - - -def check_3d_parallel_rank(rank): - ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) - wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) - op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) - - # check for input parallel group - for i in range(2): - _ranks = list(range(i * 2, 32, 4)) - _ranks_plus_one = [val + 1 for val in _ranks] - input_ranks = _ranks + _ranks_plus_one - if rank in input_ranks: - assert ip_rank == i - - # check for weight parallel group - for i in range(2): - ranks = list(range(i, 32, 2)) - - if rank in ranks: - assert wp_rank == i - - # check for output parallel group - for i in range(2): - ranks = [] - for j in range(i * 4, 32, 8): - ranks.extend([j + k for k in range(4)]) - if rank in ranks: - assert op_rank == i - - -def init_3d(rank, world_size, backend, port, host): - dist_args = dict( - config=CONFIG_PATH, - rank=rank, - world_size=world_size, - backend=backend, - port=port, - host=host, - verbose=True - ) - launch(**dist_args) - check_tensor_parallel_rank(rank) - check_3d_parallel_rank(rank) - check_data_parallel_rank(rank) - check_pipeline_parallel_rank(rank) - check_model_parallel_rank(rank) - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.cpu -def test_3d_init(): - """ - As no computation or communication is done, we can run this test on CPU. 
- """ - world_size = 32 - test_fn = partial(init_3d, - world_size=world_size, - backend='gloo', - port=free_port(), - host='localhost' - ) - mp.spawn(test_fn, nprocs=world_size) - - -if __name__ == '__main__': - test_3d_init() diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py new file mode 100644 index 000000000..d4075ef0b --- /dev/null +++ b/tests/test_context/test_hybrid_parallel.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial +from pathlib import Path +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai import launch +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import free_port +from colossalai.context import reset_seeds +from colossalai.global_variables import tensor_parallel_env as tp_env + +CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) + + +def check_data_parallel_rank(rank): + global_world_size = gpc.get_world_size(ParallelMode.GLOBAL) + mp_size = gpc.get_world_size(ParallelMode.MODEL) + num_dp_groups = global_world_size // mp_size + dp_local_rank = gpc.get_local_rank(ParallelMode.DATA) + + assert gpc.get_world_size(ParallelMode.DATA) == num_dp_groups + + for group_idx in range(num_dp_groups): + ranks_in_dp_group = range(group_idx * mp_size, (group_idx + 1) * mp_size) + if rank in ranks_in_dp_group: + assert dp_local_rank == group_idx + + +def check_pipeline_parallel_rank(rank): + mp_world_size = gpc.get_world_size(ParallelMode.MODEL) + tp_world_size = gpc.get_world_size(ParallelMode.TENSOR) + num_pipeline_stage = mp_world_size // tp_world_size + pipeline_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + for stage_idx in range(num_pipeline_stage): + ranks_in_current_stage = range(stage_idx * tp_world_size, (stage_idx + 1) * tp_world_size) + if rank in ranks_in_current_stage: + assert stage_idx == pipeline_local_rank + + +def check_model_parallel_rank(rank): + mp_size = gpc.get_world_size(ParallelMode.MODEL) + rank_within_mp_group = rank % mp_size + mp_local_rank = gpc.get_local_rank(ParallelMode.MODEL) + assert rank_within_mp_group == mp_local_rank + + +def check_tensor_parallel_rank(rank): + if tp_env.mode == '2d': + check_2d_tensor_parallel_rank(rank) + elif tp_env == '2.5d': + check_2p5d_tensor_parallel_rank(rank) + elif tp_env == '3d': + check_3d_tensor_parallel_rank(rank) + + +def get_tp_info(): + global_world_size = gpc.get_world_size(ParallelMode.GLOBAL) + tp_world_size = gpc.get_world_size(ParallelMode.TENSOR) + num_tp_groups = global_world_size // tp_world_size + tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR) + return tp_local_rank, tp_world_size, num_tp_groups + + +def check_2d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + col_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + row_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + + assert col_local_rank == tp_local_rank // tp_env.summa_dim + assert row_local_rank == tp_local_rank % tp_env.summa_dim + + +def check_2p5d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, 
(group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ) + + assert rp_rank == tp_local_rank % tp_env.summa_dim + assert cp_rank == tp_local_rank // tp_env.tesseract_dim + assert dp_rank == tp_local_rank // (tp_env.summa_dim**2) + assert xp_rank == tp_local_rank // tp_env.summa_dim + + +def check_3d_tensor_parallel_rank(rank): + tp_local_rank, tp_world_size, num_tp_groups = get_tp_info() + + for group_id in range(num_tp_groups): + ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size) + + if rank in ranks_in_current_tp_group: + ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT) + wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT) + op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT) + + assert ip_rank == tp_local_rank % tp_env.depth_3d + assert wp_rank == tp_local_rank // tp_env.depth_3d + assert op_rank == tp_local_rank // (tp_env.depth_3d**2) + + +def init_context(config_path, rank, world_size, backend, port, host): + dist_args = dict(config=config_path, + rank=rank, + world_size=world_size, + backend=backend, + port=port, + host=host, + verbose=True) + launch(**dist_args) + + check_tensor_parallel_rank(rank) + check_data_parallel_rank(rank) + check_pipeline_parallel_rank(rank) + check_model_parallel_rank(rank) + gpc.destroy() + torch.cuda.empty_cache() + + +def run_dist(rank, world_size, backend, port_list, host): + for config_path, port in zip(CONFIG_PATH_LIST, port_list): + init_context(config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host) + reset_seeds() + + +@pytest.mark.cpu +def test_context(): + """ + As no computation or communication is done, we can run this test on CPU. + """ + world_size = 32 + port_list = [] + + for _ in range(len(CONFIG_PATH_LIST)): + while True: + port = free_port() + if port not in port_list: + port_list.append(port) + break + + test_fn = partial(run_dist, world_size=world_size, backend='gloo', port_list=port_list, host='localhost') + mp.spawn(test_fn, nprocs=world_size) + + +if __name__ == '__main__': + test_context()
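
The new tests/test_context/test_hybrid_parallel.py derives the expected DATA / PIPELINE / MODEL local ranks from the parallel world sizes instead of hard-coding them per topology as the three deleted init tests did. Below is a minimal standalone sketch of that rank arithmetic; the helper name and the example sizes are illustrative assumptions and are not part of this commit.

    # Sketch only: reproduces the rank layout assumed by check_data_parallel_rank,
    # check_pipeline_parallel_rank and check_model_parallel_rank in the new test,
    # without any colossalai dependency.
    def expected_local_ranks(rank: int, tp_size: int, pp_size: int, dp_size: int):
        mp_size = tp_size * pp_size               # one model-parallel group spans all pipeline stages
        assert rank < dp_size * mp_size
        dp_local_rank = rank // mp_size           # consecutive blocks of mp_size ranks share a DP index
        mp_local_rank = rank % mp_size            # position inside the model-parallel group
        pp_local_rank = mp_local_rank // tp_size  # pipeline stage within that group
        return dp_local_rank, pp_local_rank, mp_local_rank

    if __name__ == '__main__':
        # dp=2, pp=2, tp=8 over 32 ranks -- the layout the deleted 2.5D/3D init tests hard-coded.
        for rank in range(32):
            dp, pp, mp = expected_local_ranks(rank, tp_size=8, pp_size=2, dp_size=2)
            assert (dp, pp, mp) == (rank // 16, (rank % 16) // 8, rank % 16)
        print('rank layout matches the expectations hard-coded in the deleted tests')

Because the checks are expressed in terms of world sizes, the same test body can be reused across every config file in tests/test_context/configs, which is what allows the three per-topology test files to be collapsed into the single hybrid test.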