From b5a3a4a65f1a3196faaaf0affe2c3d6ff8f7acb1 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306
Date: Thu, 5 Jan 2023 17:21:29 +0800
Subject: [PATCH] [device] find best logical mesh

---
 colossalai/device/alpha_beta_profiler.py      | 193 +++++++++++++++++-
 tests/test_device/test_extract_alpha_beta.py  |  39 ++++
 .../test_search_logical_device_mesh.py        |  36 ++++
 3 files changed, 265 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_device/test_extract_alpha_beta.py
 create mode 100644 tests/test_device/test_search_logical_device_mesh.py

diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py
index 324acacb8..9c66cb85d 100644
--- a/colossalai/device/alpha_beta_profiler.py
+++ b/colossalai/device/alpha_beta_profiler.py
@@ -21,7 +21,7 @@ class AlphaBetaProfiler:
         # multi-process with multi-gpu in mpi style.
         >>> physical_devices = [0, 1, 4, 5]
         >>> ab_profiler = AlphaBetaProfiler(physical_devices)
-        >>> ab_dict = profiler.profile_ab()
+        >>> ab_dict = ab_profiler.alpha_beta_dict
         >>> print(ab_dict)
         {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11), (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
@@ -31,13 +31,17 @@ class AlphaBetaProfiler:

     def __init__(self,
                  physical_devices: List[int],
+                 alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
                  ctype: str = 'a',
                  warmup: int = 5,
                  repeat: int = 25,
-                 latency_iters: int = 5):
+                 latency_iters: int = 5,
+                 homogeneous_tolerance: float = 0.1):
         '''
         Args:
             physical_devices: A list of device id, each element inside it is the global rank of that device.
+            alpha_beta_dict: A dict which maps a process group to its (alpha, beta) value pair.
             ctype: 'a' for all-reduce, 'b' for broadcast.
             warmup: Number of warmup iterations.
             repeat: Number of iterations to measure.
+            homogeneous_tolerance: Relative tolerance within which two beta values are treated as homogeneous.
@@ -49,8 +52,13 @@ class AlphaBetaProfiler:
         self.warmup = warmup
         self.repeat = repeat
         self.latency_iters = latency_iters
+        self.homogeneous_tolerance = homogeneous_tolerance
         self.process_group_dict = None
         self._init_profiling()
+        if alpha_beta_dict is None:
+            self.alpha_beta_dict = self.profile_ab()
+        else:
+            self.alpha_beta_dict = alpha_beta_dict

     def _init_profiling(self):
         # Create process group list based on its global rank
@@ -139,7 +147,7 @@ class AlphaBetaProfiler:

         return latency

-    def profile_bandwidth(self, process_group, pg_handler, maxbytes):
+    def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
         '''
         This function is used to profile the bandwidth of the given process group.

@@ -159,6 +167,7 @@ class AlphaBetaProfiler:
         '''
         alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
         rank = dist.get_rank()
+        global_pg_handler = dist.new_group(self.physical_devices)

         def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
             assert rank in process_group
@@ -197,3 +206,181 @@ class AlphaBetaProfiler:
             alpha_beta_dict.update(symmetry_ab_dict)

         return alpha_beta_dict
+
+    def search_best_logical_mesh(self):
+        '''
+        This method searches for the best logical mesh for the given device list.
+
+        The best logical mesh is searched in the following steps:
+            1. Detect homogeneous device groups. We assume the devices in the alpha_beta_dict
+               are homogeneous if their beta values are close enough.
+            2. Find the best homogeneous device group that contains all the physical devices.
+               The best homogeneous device group is the one with the lowest beta value among
+               the groups that contain all the physical devices. We require the group to
+               contain all the physical devices because any device outside the group would
+               decrease the bandwidth of the group.
+            3. If a best homogeneous device group is found, construct the largest ring for
+               each device based on that group; the best logical mesh is the union of all
+               the rings. Otherwise, fall back to the balanced logical mesh, e.g. shape
+               (2, 2) for 4 devices.
+
+        Returns:
+            best_logical_mesh: The best logical mesh for the given device list.
+
+        Usage:
+            >>> physical_devices = [0, 1, 2, 3]
+            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
+            >>> best_logical_mesh = ab_profiler.search_best_logical_mesh()
+            >>> print(best_logical_mesh)
+            [[0, 1], [2, 3]]
+        '''
+
+        def _power_of_two(integer):
+            return integer & (integer - 1) == 0
+
+        def _detect_homogeneous_device(alpha_beta_dict):
+            '''
+            This function detects whether the devices in the alpha_beta_dict are homogeneous.
+
+            Note: we assume the devices in the alpha_beta_dict are homogeneous if their beta
+            values fall within the range [(1 - self.homogeneous_tolerance) * base_beta,
+            (1 + self.homogeneous_tolerance) * base_beta].
+            '''
+            homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
+            for process_group, (_, beta) in alpha_beta_dict.items():
+                # Look for an existing beta bucket whose value is within the tolerance;
+                # on the first iteration the dict is empty, so no bucket can match and
+                # the else branch below seeds the first bucket.
+                match_beta = None
+                for beta_value in homogeneous_device_dict.keys():
+                    if beta_value * (1 - self.homogeneous_tolerance) <= beta <= beta_value * (
+                            1 + self.homogeneous_tolerance):
+                        match_beta = beta_value
+                        break
+
+                if match_beta is not None:
+                    homogeneous_device_dict[match_beta].append(process_group)
+                else:
+                    homogeneous_device_dict[beta] = [process_group]
+
+            return homogeneous_device_dict
+
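+        # Worked example (illustrative): with the profiled values from the class
+        # docstring and the default tolerance of 0.1, the betas ~4.74e-12 and
+        # ~4.75e-12 of (0, 1) and (4, 5) fall into one bucket, while the ~7e-11
+        # betas of (0, 4), (0, 5), (1, 4) and (1, 5) fall into another, yielding
+        # two homogeneous device groups.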
+        def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
+            '''
+            This function checks whether the homogeneous_group covers all physical devices.
+            '''
+            flatten_mesh = []
+            for process_group in homogeneous_group:
+                flatten_mesh.extend(process_group)
+            non_duplicated_flatten_mesh = set(flatten_mesh)
+            return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
+
+        def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
+            '''
+            This function constructs the largest ring in the homogeneous_group for each rank.
+            '''
+            # Construct the rings
+            ring = []
+            ranks_in_ring = []
+            for rank in self.physical_devices:
+                if rank in ranks_in_ring:
+                    continue
+                stable_status = False
+                ring_for_rank = [rank]
+                check_rank_list = [rank]
+                rank_to_check_list = []
+
+                # Grow the ring until no new neighbour can be absorbed.
+                while not stable_status:
+                    stable_status = True
+                    check_rank_list.extend(rank_to_check_list)
+                    rank_to_check_list = []
+                    for _ in range(len(check_rank_list)):
+                        check_rank = check_rank_list.pop()
+                        for process_group in homogeneous_group:
+                            if check_rank in process_group:
+                                rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
+                                if rank_to_append not in ring_for_rank:
+                                    stable_status = False
+                                    rank_to_check_list.append(rank_to_append)
+                                    ring_for_rank.append(rank_to_append)
+
+                ring.append(ring_for_rank)
+                ranks_in_ring.extend(ring_for_rank)
+
+            return ring
+
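+        # Worked example (illustrative): with homogeneous_group = [(0, 1), (1, 2),
+        # (2, 3), (0, 3)], the search starting from rank 0 absorbs ranks 1 and 3,
+        # then rank 2, and returns the single ring [0, 1, 3, 2]; with the disjoint
+        # group [(0, 1), (2, 3)] it returns the two rings [0, 1] and [2, 3].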
+        assert _power_of_two(self.world_size), 'world_size must be a power of two to build a balanced logical mesh.'
+        # Fallback mesh: a balanced 2D mesh, e.g. (2, 2) for 4 devices.
+        power_of_two = int(math.log2(self.world_size))
+        median = power_of_two // 2
+        balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
+        row_size, column_size = balanced_logical_mesh_shape
+        balanced_logical_mesh = []
+        for row_index in range(row_size):
+            balanced_logical_mesh.append([])
+            for column_index in range(column_size):
+                balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])
+
+        homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
+        beta_list = sorted(homogeneous_device_dict.keys(), reverse=True)
+        homogeneous_types = len(beta_list)
+        best_logical_mesh = None
+        if homogeneous_types >= 2:
+            for _ in range(homogeneous_types - 1):
+                lowest_beta = beta_list.pop()
+                best_homogeneous_group = homogeneous_device_dict[lowest_beta]
+                # If the best homogeneous group contains all physical devices,
+                # we build the logical device mesh based on it. Otherwise,
+                # we check the next-level homogeneous group.
+                if _check_contain_all_devices(best_homogeneous_group):
+                    # We choose the largest ring for each rank to maximize bus utilization.
+                    best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
+                    break
+
+        if homogeneous_types == 1 or best_logical_mesh is None:
+            # In this case, we fall back to the balanced logical mesh.
+            best_logical_mesh = balanced_logical_mesh
+
+        return best_logical_mesh
+
+    def extract_alpha_beta_for_device_mesh(self):
+        '''
+        Extract the mesh_alpha list and mesh_beta list based on the
+        best logical mesh, which will be used to initialize the device mesh.
+
+        Returns:
+            mesh_alpha: A list of alpha values, one per axis of the best logical mesh.
+            mesh_beta: A list of beta values, one per axis of the best logical mesh.
+
+        Usage:
+            >>> physical_devices = [0, 1, 2, 3]
+            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
+            >>> mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
+            >>> print(mesh_alpha)
+            [2.5917552411556242e-05, 0.00010312341153621673]
+            >>> print(mesh_beta)
+            [5.875573704655635e-11, 4.7361584445959614e-12]
+        '''
+        best_logical_mesh = self.search_best_logical_mesh()
+
+        first_axis = [row[0] for row in best_logical_mesh]
+        second_axis = best_logical_mesh[0]
+
+        # init process group for both axes
+        first_axis_process_group = dist.new_group(first_axis)
+        second_axis_process_group = dist.new_group(second_axis)
+
+        # extract alpha and beta for both axes
+        def _extract_alpha_beta(pg, pg_handler):
+            latency = self.profile_latency(pg, pg_handler)
+            bandwidth = self.profile_bandwidth(pg, pg_handler)
+            broadcast_object = [latency, bandwidth]
+            dist.broadcast_object_list(broadcast_object, src=pg[0])
+            return broadcast_object
+
+        first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
+        second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
+        mesh_alpha = [first_latency, second_latency]
+        # beta is the reciprocal of the measured bandwidth (time per byte)
+        mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
+
+        return mesh_alpha, mesh_beta
diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py
new file mode 100644
index 000000000..e32bebdd9
--- /dev/null
+++ b/tests/test_device/test_extract_alpha_beta.py
@@ -0,0 +1,39 @@
+from functools import partial
+
+import pytest
+import torch.multiprocessing as mp
+
+from colossalai.device import AlphaBetaProfiler
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_extract_alpha_beta(rank, physical_devices, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    profiler = AlphaBetaProfiler(physical_devices)
+
+    mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
+    for alpha in mesh_alpha:
+        assert 0 < alpha < 1e-3
+    for beta in mesh_beta:
+        assert 0 < beta < 1e-10
+
+
+@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
+@pytest.mark.dist
+@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
+@rerun_if_address_is_in_use()
+def test_extract_alpha_beta(physical_devices):
+    world_size = 4
+    run_func = partial(check_extract_alpha_beta,
+                       physical_devices=physical_devices,
+                       world_size=world_size,
+                       port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_extract_alpha_beta()
diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py
new file mode 100644
index 000000000..591eafb2a
--- /dev/null
+++ b/tests/test_device/test_search_logical_device_mesh.py
@@ -0,0 +1,36 @@
+from functools import partial
+
+import pytest
+import torch.multiprocessing as mp
+
+from colossalai.device import AlphaBetaProfiler
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_alpha_beta(rank, physical_devices, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    profiler = AlphaBetaProfiler(physical_devices)
+    best_logical_mesh = profiler.search_best_logical_mesh()
+
+    if physical_devices == [0, 1, 2, 3]:
+        assert best_logical_mesh == [[0, 1], [2, 3]]
+    elif physical_devices == [0, 3]:
+        assert best_logical_mesh == [[0, 3]]
+
+
+@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
+@pytest.mark.dist
+@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
+@rerun_if_address_is_in_use()
+def test_search_logical_device_mesh(physical_devices):
+    world_size = 4
+    run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_search_logical_device_mesh()
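
A minimal end-to-end sketch of how the three new entry points fit together on a single 4-GPU node. It assumes every process has already called colossalai.initialize.launch (as in the tests above); the DeviceMesh import path and constructor keyword arguments are assumptions for illustration, not APIs confirmed by this patch.

    import torch

    from colossalai.device import AlphaBetaProfiler
    from colossalai.device.device_mesh import DeviceMesh    # import path assumed

    # Runs identically on every rank; the default process group must already exist.
    physical_devices = [0, 1, 2, 3]
    profiler = AlphaBetaProfiler(physical_devices)

    # Pairwise (alpha, beta) values, profiled eagerly in __init__ unless an
    # alpha_beta_dict was passed in.
    print(profiler.alpha_beta_dict)

    # Best 2D logical mesh for these devices, e.g. [[0, 1], [2, 3]].
    logical_mesh = profiler.search_best_logical_mesh()

    # Per-axis latency/bandwidth terms for the two mesh axes.
    mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()

    # Hypothetical hand-off to a device mesh (constructor signature assumed).
    physical_mesh_id = torch.tensor(physical_devices)
    mesh_shape = (len(logical_mesh), len(logical_mesh[0]))
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, mesh_alpha=mesh_alpha, mesh_beta=mesh_beta)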