[device] find best logical mesh

pull/2389/head
Jiarui Fang 2023-01-07 14:04:30 +08:00 committed by GitHub
commit 4e96039649
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 265 additions and 3 deletions

View File

@ -21,7 +21,7 @@ class AlphaBetaProfiler:
# multi-process with multi-gpu in mpi style.
>>> physical_devices = [0, 1, 4, 5]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> ab_dict = profiler.profile_ab()
>>> ab_dict = profiler.alpha_beta_dict
>>> print(ab_dict)
{(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
@ -31,13 +31,16 @@ class AlphaBetaProfiler:
def __init__(self,
physical_devices: List[int],
alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
ctype: str = 'a',
warmup: int = 5,
repeat: int = 25,
latency_iters: int = 5):
latency_iters: int = 5,
homogeneous_tolerance: float = 0.1):
'''
Args:
physical_devices: A list of device id, each element inside it is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
ctype: 'a' for all-reduce, 'b' for broadcast.
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
@ -49,8 +52,13 @@ class AlphaBetaProfiler:
self.warmup = warmup
self.repeat = repeat
self.latency_iters = latency_iters
self.homogeneous_tolerance = homogeneous_tolerance
self.process_group_dict = None
self._init_profiling()
if alpha_beta_dict is None:
self.alpha_beta_dict = self.profile_ab()
else:
self.alpha_beta_dict = alpha_beta_dict
def _init_profiling(self):
# Create process group list based on its global rank
@ -139,7 +147,7 @@ class AlphaBetaProfiler:
return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes):
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
'''
This function is used to profile the bandwidth of the given process group.
@ -159,6 +167,7 @@ class AlphaBetaProfiler:
'''
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
global_pg_handler = dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
@ -197,3 +206,181 @@ class AlphaBetaProfiler:
alpha_beta_dict.update(symmetry_ab_dict)
return alpha_beta_dict
def search_best_logical_mesh(self):
'''
This method is used to search the best logical mesh for the given device list.
The best logical mesh is searched in following steps:
1. detect homogeneous device groups, we assume that the devices in the alpha_beta_dict
are homogeneous if the beta value is close enough.
2. Find the best homogeneous device group contains all the physical devices. The best homogeneous
device group means the lowest beta value in the groups which contains all the physical devices.
And the reason we require the group contains all the physical devices is that the devices not in
the group will decrease the bandwidth of the group.
3. If the best homogeneous device group is found, we will construct the largest ring for each device
based on the best homogeneous device group, and the best logical mesh will be the union of all the
rings. Otherwise, the best logical mesh will be the balanced logical mesh, such as shape (2, 2) for
4 devices.
Returns:
best_logical_mesh: The best logical mesh for the given device list.
Usage:
>>> physical_devices = [0, 1, 2, 3]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> best_logical_mesh = profiler.search_best_logical_mesh()
>>> print(best_logical_mesh)
[[0, 1], [2, 3]]
'''
def _power_of_two(integer):
return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict):
'''
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
* base_beta.
'''
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items():
if homogeneous_device_dict is None:
homogeneous_device_dict[beta] = []
homogeneous_device_dict[beta].append(process_group)
match_beta = None
for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
1 - self.homogeneous_tolerance):
match_beta = beta_value
break
if match_beta is not None:
homogeneous_device_dict[match_beta].append(process_group)
else:
homogeneous_device_dict[beta] = []
homogeneous_device_dict[beta].append(process_group)
return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
'''
This function is used to check whether the homogeneous_group contains all physical devices.
'''
flatten_mesh = []
for process_group in homogeneous_group:
flatten_mesh.extend(process_group)
non_duplicated_flatten_mesh = set(flatten_mesh)
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
'''
This function is used to construct the largest ring in the homogeneous_group for each rank.
'''
# Construct the ring
ring = []
ranks_in_ring = []
for rank in self.physical_devices:
if rank in ranks_in_ring:
continue
stable_status = False
ring_for_rank = []
ring_for_rank.append(rank)
check_rank_list = [rank]
rank_to_check_list = []
while not stable_status:
stable_status = True
check_rank_list.extend(rank_to_check_list)
rank_to_check_list = []
for i in range(len(check_rank_list)):
check_rank = check_rank_list.pop()
for process_group in homogeneous_group:
if check_rank in process_group:
rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
if rank_to_append not in ring_for_rank:
stable_status = False
rank_to_check_list.append(rank_to_append)
ring_for_rank.append(rank_to_append)
ring.append(ring_for_rank)
ranks_in_ring.extend(ring_for_rank)
return ring
assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2
balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = []
for row_index in range(row_size):
balanced_logical_mesh.append([])
for column_index in range(column_size):
balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])
homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
beta_list = [b for b in homogeneous_device_dict.keys()]
beta_list.sort()
beta_list.reverse()
homogeneous_types = len(beta_list)
best_logical_mesh = None
if homogeneous_types >= 2:
for _ in range(homogeneous_types - 1):
lowest_beta = beta_list.pop()
best_homogeneous_group = homogeneous_device_dict[lowest_beta]
# if the best homogeneous group contains all physical devices,
# we will build the logical device mesh based on it. Otherwise,
# we will check next level homogeneous group.
if _check_contain_all_devices(best_homogeneous_group):
# We choose the largest ring for each rank to maximum the best bus utilization.
best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
break
if homogeneous_types == 1 or best_logical_mesh is None:
# in this case, we use balanced logical mesh as the best
# logical mesh.
best_logical_mesh = balanced_logical_mesh
return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self):
'''
Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh.
Usage:
>>> physical_devices = [0, 1, 2, 3]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
>>> mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
>>> print(mesh_alpha)
[2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12]
'''
best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh]
second_axis = best_logical_mesh[0]
# init process group for both axes
first_axis_process_group = dist.new_group(first_axis)
second_axis_process_group = dist.new_group(second_axis)
# extract alpha and beta for both axes
def _extract_alpha_beta(pg, pg_handler):
latency = self.profile_latency(pg, pg_handler)
bandwidth = self.profile_bandwidth(pg, pg_handler)
broadcast_object = [latency, bandwidth]
dist.broadcast_object_list(broadcast_object, src=pg[0])
return broadcast_object
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
mesh_alpha = [first_latency, second_latency]
mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
return mesh_alpha, mesh_beta

View File

@ -0,0 +1,39 @@
from functools import partial
import pytest
import torch.multiprocessing as mp
from colossalai.device import AlphaBetaProfiler
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
def check_extract_alpha_beta(rank, physical_devices, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
profiler = AlphaBetaProfiler(physical_devices)
mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
for alpha in mesh_alpha:
assert alpha > 0 and alpha < 1e-3
for beta in mesh_beta:
assert beta > 0 and beta < 1e-10
@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
@pytest.mark.dist
@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
@rerun_if_address_is_in_use()
def test_profile_alpha_beta(physical_devices):
world_size = 4
run_func = partial(check_extract_alpha_beta,
physical_devices=physical_devices,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_profile_alpha_beta()

View File

@ -0,0 +1,36 @@
from functools import partial
import pytest
import torch.multiprocessing as mp
from colossalai.device import AlphaBetaProfiler
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
def check_alpha_beta(rank, physical_devices, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
profiler = AlphaBetaProfiler(physical_devices)
best_logical_mesh = profiler.search_best_logical_mesh()
if physical_devices == [0, 1, 2, 3]:
assert best_logical_mesh == [[0, 1], [2, 3]]
elif physical_devices == [0, 3]:
assert best_logical_mesh == [[0, 3]]
@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
@pytest.mark.dist
@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
@rerun_if_address_is_in_use()
def test_profile_alpha_beta(physical_devices):
world_size = 4
run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_profile_alpha_beta()