mirror of https://github.com/hpcaitech/ColossalAI
[device] find best logical mesh
commit
4e96039649
|
@ -21,7 +21,7 @@ class AlphaBetaProfiler:
|
|||
# multi-process with multi-gpu in mpi style.
|
||||
>>> physical_devices = [0, 1, 4, 5]
|
||||
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
|
||||
>>> ab_dict = profiler.profile_ab()
|
||||
>>> ab_dict = profiler.alpha_beta_dict
|
||||
>>> print(ab_dict)
|
||||
{(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
|
||||
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
|
||||
|
@ -31,13 +31,16 @@ class AlphaBetaProfiler:
|
|||
|
||||
def __init__(self,
|
||||
physical_devices: List[int],
|
||||
alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
|
||||
ctype: str = 'a',
|
||||
warmup: int = 5,
|
||||
repeat: int = 25,
|
||||
latency_iters: int = 5):
|
||||
latency_iters: int = 5,
|
||||
homogeneous_tolerance: float = 0.1):
|
||||
'''
|
||||
Args:
|
||||
physical_devices: A list of device id, each element inside it is the global rank of that device.
|
||||
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
|
||||
ctype: 'a' for all-reduce, 'b' for broadcast.
|
||||
warmup: Number of warmup iterations.
|
||||
repeat: Number of iterations to measure.
|
||||
|
@ -49,8 +52,13 @@ class AlphaBetaProfiler:
|
|||
self.warmup = warmup
|
||||
self.repeat = repeat
|
||||
self.latency_iters = latency_iters
|
||||
self.homogeneous_tolerance = homogeneous_tolerance
|
||||
self.process_group_dict = None
|
||||
self._init_profiling()
|
||||
if alpha_beta_dict is None:
|
||||
self.alpha_beta_dict = self.profile_ab()
|
||||
else:
|
||||
self.alpha_beta_dict = alpha_beta_dict
|
||||
|
||||
def _init_profiling(self):
|
||||
# Create process group list based on its global rank
|
||||
|
@ -139,7 +147,7 @@ class AlphaBetaProfiler:
|
|||
|
||||
return latency
|
||||
|
||||
def profile_bandwidth(self, process_group, pg_handler, maxbytes):
|
||||
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
|
||||
'''
|
||||
This function is used to profile the bandwidth of the given process group.
|
||||
|
||||
|
@ -159,6 +167,7 @@ class AlphaBetaProfiler:
|
|||
'''
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
|
||||
rank = dist.get_rank()
|
||||
global_pg_handler = dist.new_group(self.physical_devices)
|
||||
|
||||
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
|
||||
assert rank in process_group
|
||||
|
@ -197,3 +206,181 @@ class AlphaBetaProfiler:
|
|||
alpha_beta_dict.update(symmetry_ab_dict)
|
||||
|
||||
return alpha_beta_dict
|
||||
|
||||
def search_best_logical_mesh(self):
|
||||
'''
|
||||
This method is used to search the best logical mesh for the given device list.
|
||||
|
||||
The best logical mesh is searched in following steps:
|
||||
1. detect homogeneous device groups, we assume that the devices in the alpha_beta_dict
|
||||
are homogeneous if the beta value is close enough.
|
||||
2. Find the best homogeneous device group contains all the physical devices. The best homogeneous
|
||||
device group means the lowest beta value in the groups which contains all the physical devices.
|
||||
And the reason we require the group contains all the physical devices is that the devices not in
|
||||
the group will decrease the bandwidth of the group.
|
||||
3. If the best homogeneous device group is found, we will construct the largest ring for each device
|
||||
based on the best homogeneous device group, and the best logical mesh will be the union of all the
|
||||
rings. Otherwise, the best logical mesh will be the balanced logical mesh, such as shape (2, 2) for
|
||||
4 devices.
|
||||
|
||||
Returns:
|
||||
best_logical_mesh: The best logical mesh for the given device list.
|
||||
|
||||
Usage:
|
||||
>>> physical_devices = [0, 1, 2, 3]
|
||||
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
|
||||
>>> best_logical_mesh = profiler.search_best_logical_mesh()
|
||||
>>> print(best_logical_mesh)
|
||||
[[0, 1], [2, 3]]
|
||||
'''
|
||||
|
||||
def _power_of_two(integer):
|
||||
return integer & (integer - 1) == 0
|
||||
|
||||
def _detect_homogeneous_device(alpha_beta_dict):
|
||||
'''
|
||||
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
|
||||
|
||||
Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
|
||||
of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
|
||||
* base_beta.
|
||||
'''
|
||||
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
|
||||
for process_group, (_, beta) in alpha_beta_dict.items():
|
||||
if homogeneous_device_dict is None:
|
||||
homogeneous_device_dict[beta] = []
|
||||
homogeneous_device_dict[beta].append(process_group)
|
||||
|
||||
match_beta = None
|
||||
for beta_value in homogeneous_device_dict.keys():
|
||||
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
|
||||
1 - self.homogeneous_tolerance):
|
||||
match_beta = beta_value
|
||||
break
|
||||
|
||||
if match_beta is not None:
|
||||
homogeneous_device_dict[match_beta].append(process_group)
|
||||
else:
|
||||
homogeneous_device_dict[beta] = []
|
||||
homogeneous_device_dict[beta].append(process_group)
|
||||
|
||||
return homogeneous_device_dict
|
||||
|
||||
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
|
||||
'''
|
||||
This function is used to check whether the homogeneous_group contains all physical devices.
|
||||
'''
|
||||
flatten_mesh = []
|
||||
for process_group in homogeneous_group:
|
||||
flatten_mesh.extend(process_group)
|
||||
non_duplicated_flatten_mesh = set(flatten_mesh)
|
||||
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
|
||||
|
||||
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
|
||||
'''
|
||||
This function is used to construct the largest ring in the homogeneous_group for each rank.
|
||||
'''
|
||||
# Construct the ring
|
||||
ring = []
|
||||
ranks_in_ring = []
|
||||
for rank in self.physical_devices:
|
||||
if rank in ranks_in_ring:
|
||||
continue
|
||||
stable_status = False
|
||||
ring_for_rank = []
|
||||
ring_for_rank.append(rank)
|
||||
check_rank_list = [rank]
|
||||
rank_to_check_list = []
|
||||
|
||||
while not stable_status:
|
||||
stable_status = True
|
||||
check_rank_list.extend(rank_to_check_list)
|
||||
rank_to_check_list = []
|
||||
for i in range(len(check_rank_list)):
|
||||
check_rank = check_rank_list.pop()
|
||||
for process_group in homogeneous_group:
|
||||
if check_rank in process_group:
|
||||
rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
|
||||
if rank_to_append not in ring_for_rank:
|
||||
stable_status = False
|
||||
rank_to_check_list.append(rank_to_append)
|
||||
ring_for_rank.append(rank_to_append)
|
||||
|
||||
ring.append(ring_for_rank)
|
||||
ranks_in_ring.extend(ring_for_rank)
|
||||
|
||||
return ring
|
||||
|
||||
assert _power_of_two(self.world_size)
|
||||
power_of_two = int(math.log2(self.world_size))
|
||||
median = power_of_two // 2
|
||||
balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
|
||||
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
|
||||
balanced_logical_mesh = []
|
||||
for row_index in range(row_size):
|
||||
balanced_logical_mesh.append([])
|
||||
for column_index in range(column_size):
|
||||
balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])
|
||||
|
||||
homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
|
||||
beta_list = [b for b in homogeneous_device_dict.keys()]
|
||||
beta_list.sort()
|
||||
beta_list.reverse()
|
||||
homogeneous_types = len(beta_list)
|
||||
best_logical_mesh = None
|
||||
if homogeneous_types >= 2:
|
||||
for _ in range(homogeneous_types - 1):
|
||||
lowest_beta = beta_list.pop()
|
||||
best_homogeneous_group = homogeneous_device_dict[lowest_beta]
|
||||
# if the best homogeneous group contains all physical devices,
|
||||
# we will build the logical device mesh based on it. Otherwise,
|
||||
# we will check next level homogeneous group.
|
||||
if _check_contain_all_devices(best_homogeneous_group):
|
||||
# We choose the largest ring for each rank to maximum the best bus utilization.
|
||||
best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
|
||||
break
|
||||
|
||||
if homogeneous_types == 1 or best_logical_mesh is None:
|
||||
# in this case, we use balanced logical mesh as the best
|
||||
# logical mesh.
|
||||
best_logical_mesh = balanced_logical_mesh
|
||||
|
||||
return best_logical_mesh
|
||||
|
||||
def extract_alpha_beta_for_device_mesh(self):
|
||||
'''
|
||||
Extract the mesh_alpha list and mesh_beta list based on the
|
||||
best logical mesh, which will be used to initialize the device mesh.
|
||||
|
||||
Usage:
|
||||
>>> physical_devices = [0, 1, 2, 3]
|
||||
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
|
||||
>>> mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
|
||||
>>> print(mesh_alpha)
|
||||
[2.5917552411556242e-05, 0.00010312341153621673]
|
||||
>>> print(mesh_beta)
|
||||
[5.875573704655635e-11, 4.7361584445959614e-12]
|
||||
'''
|
||||
best_logical_mesh = self.search_best_logical_mesh()
|
||||
|
||||
first_axis = [row[0] for row in best_logical_mesh]
|
||||
second_axis = best_logical_mesh[0]
|
||||
|
||||
# init process group for both axes
|
||||
first_axis_process_group = dist.new_group(first_axis)
|
||||
second_axis_process_group = dist.new_group(second_axis)
|
||||
|
||||
# extract alpha and beta for both axes
|
||||
def _extract_alpha_beta(pg, pg_handler):
|
||||
latency = self.profile_latency(pg, pg_handler)
|
||||
bandwidth = self.profile_bandwidth(pg, pg_handler)
|
||||
broadcast_object = [latency, bandwidth]
|
||||
dist.broadcast_object_list(broadcast_object, src=pg[0])
|
||||
return broadcast_object
|
||||
|
||||
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
|
||||
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
|
||||
mesh_alpha = [first_latency, second_latency]
|
||||
mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
|
||||
|
||||
return mesh_alpha, mesh_beta
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.device import AlphaBetaProfiler
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
||||
def check_extract_alpha_beta(rank, physical_devices, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
profiler = AlphaBetaProfiler(physical_devices)
|
||||
|
||||
mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
|
||||
for alpha in mesh_alpha:
|
||||
assert alpha > 0 and alpha < 1e-3
|
||||
for beta in mesh_beta:
|
||||
assert beta > 0 and beta < 1e-10
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
|
||||
@pytest.mark.dist
|
||||
@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_profile_alpha_beta(physical_devices):
|
||||
world_size = 4
|
||||
run_func = partial(check_extract_alpha_beta,
|
||||
physical_devices=physical_devices,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_profile_alpha_beta()
|
|
@ -0,0 +1,36 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.device import AlphaBetaProfiler
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
||||
def check_alpha_beta(rank, physical_devices, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
profiler = AlphaBetaProfiler(physical_devices)
|
||||
best_logical_mesh = profiler.search_best_logical_mesh()
|
||||
|
||||
if physical_devices == [0, 1, 2, 3]:
|
||||
assert best_logical_mesh == [[0, 1], [2, 3]]
|
||||
elif physical_devices == [0, 3]:
|
||||
assert best_logical_mesh == [[0, 3]]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
|
||||
@pytest.mark.dist
|
||||
@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_profile_alpha_beta(physical_devices):
|
||||
world_size = 4
|
||||
run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_profile_alpha_beta()
|
Loading…
Reference in New Issue