mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
393 lines
17 KiB
393 lines
17 KiB
import math |
|
import time |
|
from typing import Dict, List, Tuple |
|
|
|
import torch |
|
import torch.distributed as dist |
|
|
|
from colossalai.logging import get_dist_logger |
|
|
|
GB = int((1 << 30))  # number of bytes in one gibibyte

BYTE = 4  # smallest message size (bytes) used for latency profiling; matches fp32 element size

FRAMEWORK_LATENCY = 0  # fixed framework overhead (seconds) subtracted from each measured time
|
|
|
|
|
class AlphaBetaProfiler:
    """
    Profile alpha and beta value for a given device list.

    The communication time between a pair of devices is modeled as
    ``t(nbytes) = alpha + nbytes * beta``: ``alpha`` is the latency in seconds
    and ``beta`` the transfer time per byte (the inverse bandwidth).

    Usage:
        # Note: the environment of execution is supposed to be
        # multi-process with multi-gpu in mpi style.
        >>> physical_devices = [0, 1, 4, 5]
        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
        >>> ab_dict = ab_profiler.alpha_beta_dict
        >>> print(ab_dict)
        {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
        (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
        (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
        (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
    """

    def __init__(
        self,
        physical_devices: List[int],
        alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
        ctype: str = "a",
        warmup: int = 5,
        repeat: int = 25,
        latency_iters: int = 5,
        homogeneous_tolerance: float = 0.1,
    ):
        """
        Args:
            physical_devices: A list of device id, each element inside it is the global rank of that device.
            alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
                If None, the values are measured via profiling during construction.
            ctype: 'a' for all-reduce, 'b' for broadcast.
            warmup: Number of warmup iterations.
            repeat: Number of iterations to measure.
            latency_iters: Number of iterations to measure latency.
            homogeneous_tolerance: Relative tolerance within which two beta values are
                considered to come from the same (homogeneous) kind of link.

        Raises:
            ValueError: If ``ctype`` is neither 'a' nor 'b'.
        """
        # Fail fast on an unknown collective type: otherwise it would only
        # surface much later as an UnboundLocalError inside _profile().
        if ctype not in ("a", "b"):
            raise ValueError(f"Expected ctype to be 'a' (all-reduce) or 'b' (broadcast), but got {ctype}")
        self.physical_devices = physical_devices
        self.ctype = ctype
        self.world_size = len(physical_devices)
        self.warmup = warmup
        self.repeat = repeat
        self.latency_iters = latency_iters
        self.homogeneous_tolerance = homogeneous_tolerance
        self.process_group_dict = None
        self._init_profiling()
        if alpha_beta_dict is None:
            self.alpha_beta_dict = self.profile_ab()
        else:
            self.alpha_beta_dict = alpha_beta_dict

    def _init_profiling(self):
        """Build a process group (and its handler) for every unordered device pair."""
        # Create process group list based on its global rank
        process_group_list = []
        for f_index in range(self.world_size - 1):
            for b_index in range(f_index + 1, self.world_size):
                process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))

        # Create process group dict which maps process group to its handler.
        # NOTE: dist.new_group is a collective call, so every rank creates every
        # pair group, even the groups it does not belong to.
        process_group_dict = {}
        for process_group in process_group_list:
            pg_handler = dist.new_group(process_group)
            process_group_dict[process_group] = pg_handler

        self.process_group_dict = process_group_dict

    def _profile(self, process_group, pg_handler, nbytes):
        """
        Time the configured collective (all-reduce or broadcast) on ``pg_handler``
        with a message of ``nbytes`` bytes.

        Args:
            process_group: A tuple of global ranks of the process group.
            pg_handler: The handler of the process group.
            nbytes: Message size in bytes; the buffer holds ``nbytes // 4`` fp32 elements.

        Returns:
            (avg_time_s, alg_band) on the first rank of the group, (None, None) elsewhere.
        """
        logger = get_dist_logger()
        rank = dist.get_rank()
        src_device_num = process_group[0]
        world_size = len(process_group)

        device = torch.cuda.current_device()
        buf = torch.randn(nbytes // 4).to(device)

        torch.cuda.synchronize()
        # warmup
        for _ in range(self.warmup):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()

        dist.barrier(group=pg_handler)
        begin = time.perf_counter()
        for _ in range(self.repeat):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()
        end = time.perf_counter()
        dist.barrier(group=pg_handler)

        if rank == src_device_num:
            avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
            alg_band = nbytes / avg_time_s
            if self.ctype == "a":
                # Convert the bandwidth of the all-reduce algorithm to the bus
                # bandwidth of the hardware (ring all-reduce moves
                # 2*(n-1)/n bytes per byte of payload).
                # BUGFIX: this value used to be overwritten by ``bus_band = alg_band``
                # right after the conversion, making the conversion dead code and
                # the logged bus bandwidth wrong for all-reduce.
                bus_band = 2 * (world_size - 1) / world_size * alg_band
            else:
                # broadcast: algorithm bandwidth equals bus bandwidth
                bus_band = alg_band

            logger.info(
                f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s"
            )
            return (avg_time_s, alg_band)
        else:
            # Just a placeholder
            return (None, None)

    def profile_latency(self, process_group, pg_handler):
        """
        This function is used to profile the latency of the given process group with a series of bytes.

        Args:
            process_group: A tuple of global rank of the process group.
            pg_handler: The handler of the process group.

        Returns:
            latency: None if the latency is not measured (this rank is not the
            source rank of the group), otherwise the median of the latency_list.
        """
        # Message sizes grow as BYTE << i, i.e. 4B, 8B, 16B, ... — small enough
        # that the measured time is dominated by latency, not bandwidth.
        latency_list = []
        for i in range(self.latency_iters):
            nbytes = int(BYTE << i)
            (t, _) = self._profile(process_group, pg_handler, nbytes)
            latency_list.append(t)

        # _profile returns (None, None) on non-source ranks, so either every
        # entry is None or none of them is.
        if latency_list[0] is None:
            latency = None
        else:
            median_index = math.floor(self.latency_iters / 2)
            latency = latency_list[median_index]

        return latency

    def profile_bandwidth(self, process_group, pg_handler, maxbytes=None):
        """
        This function is used to profile the bandwidth of the given process group.

        Args:
            process_group: A tuple of global rank of the process group.
            pg_handler: The handler of the process group.
            maxbytes: Message size in bytes used for the measurement; defaults to 1 GiB.

        Returns:
            The algorithm bandwidth in bytes/s on the source rank, None elsewhere.
        """
        if maxbytes is None:
            maxbytes = 1 * GB
        (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
        return bandwidth

    def profile_ab(self):
        """
        This method is used to profiling the alpha and beta value for a given device list.

        Returns:
            alpha_beta_dict: A dict which maps process group to its alpha and beta value.
        """
        alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
        rank = dist.get_rank()
        # NOTE(review): the returned handler is unused; this call presumably
        # exists only because new_group is collective — confirm before removing.
        dist.new_group(self.physical_devices)

        def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
            """Largest message that fits on every rank: min free memory across the
            group, rounded down to a power-of-two multiple of GB and capped at 1 GiB."""
            assert rank in process_group
            device = torch.cuda.current_device()
            rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
            rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
            dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
            # NOTE(review): assumes at least 1 GiB of free memory per rank;
            # log2(free/GB) is negative otherwise and the shift would fail.
            max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
            return max_nbytes

        for process_group, pg_handler in self.process_group_dict.items():
            if rank not in process_group:
                # Ranks outside the pair still participate in the loop so the
                # broadcast below stays collective; they contribute no data.
                max_nbytes = None
                alpha = None
                bandwidth = None
            else:
                max_nbytes = get_max_nbytes(process_group, pg_handler)
                alpha = self.profile_latency(process_group, pg_handler)
                bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)

            # beta is the per-byte transfer time, i.e. the inverse bandwidth.
            if bandwidth is None:
                beta = None
            else:
                beta = 1 / bandwidth

            # The source rank of each pair owns the measurement; share it with all ranks.
            broadcast_list = [alpha, beta]
            dist.broadcast_object_list(broadcast_list, src=process_group[0])
            alpha_beta_dict[process_group] = tuple(broadcast_list)

        # add symmetry pair to the alpha_beta_dict
        symmetry_ab_dict = {}
        for process_group, alpha_beta_pair in alpha_beta_dict.items():
            symmetry_process_group = (process_group[1], process_group[0])
            symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair

        alpha_beta_dict.update(symmetry_ab_dict)

        return alpha_beta_dict

    def search_best_logical_mesh(self):
        """
        This method is used to search the best logical mesh for the given device list.

        The best logical mesh is searched in following steps:
            1. detect homogeneous device groups, we assume that the devices in the alpha_beta_dict
                are homogeneous if the beta value is close enough.
            2. Find the best homogeneous device group contains all the physical devices. The best homogeneous
                device group means the lowest beta value in the groups which contains all the physical devices.
                And the reason we require the group contains all the physical devices is that the devices not in
                the group will decrease the bandwidth of the group.
            3. If the best homogeneous device group is found, we will construct the largest ring for each device
                based on the best homogeneous device group, and the best logical mesh will be the union of all the
                rings. Otherwise, the best logical mesh will be the balanced logical mesh, such as shape (2, 2) for
                4 devices.

        Returns:
            best_logical_mesh: The best logical mesh for the given device list.

        Usage:
            >>> physical_devices = [0, 1, 2, 3]
            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
            >>> best_logical_mesh = ab_profiler.search_best_logical_mesh()
            >>> print(best_logical_mesh)
            [[0, 1], [2, 3]]
        """

        def _power_of_two(integer):
            # Standard bit trick; note it also accepts 0, but world_size >= 1 here.
            return integer & (integer - 1) == 0

        def _detect_homogeneous_device(alpha_beta_dict):
            """
            This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.

            Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
            of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
            * base_beta.
            """
            # Maps a representative beta value to every process group whose beta
            # falls within the tolerance band of that representative.
            # BUGFIX: removed a dead `if homogeneous_device_dict is None:` branch —
            # the dict is initialized to {} and can never be None.
            homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
            for process_group, (_, beta) in alpha_beta_dict.items():
                match_beta = None
                for beta_value in homogeneous_device_dict.keys():
                    if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
                        1 - self.homogeneous_tolerance
                    ):
                        match_beta = beta_value
                        break

                if match_beta is not None:
                    homogeneous_device_dict[match_beta].append(process_group)
                else:
                    # First group of a new link class becomes its representative.
                    homogeneous_device_dict[beta] = []
                    homogeneous_device_dict[beta].append(process_group)

            return homogeneous_device_dict

        def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
            """
            This function is used to check whether the homogeneous_group contains all physical devices.
            """
            flatten_mesh = []
            for process_group in homogeneous_group:
                flatten_mesh.extend(process_group)
            non_duplicated_flatten_mesh = set(flatten_mesh)
            return len(non_duplicated_flatten_mesh) == len(self.physical_devices)

        def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
            """
            This function is used to construct the largest ring in the homogeneous_group for each rank.
            """
            # Compute connected components of the graph whose edges are the
            # process groups: each component is one "ring" of devices.
            ring = []
            ranks_in_ring = []
            for rank in self.physical_devices:
                if rank in ranks_in_ring:
                    continue
                stable_status = False
                ring_for_rank = []
                ring_for_rank.append(rank)
                check_rank_list = [rank]
                rank_to_check_list = []

                # Breadth-first expansion until no new neighbor is discovered.
                while not stable_status:
                    stable_status = True
                    check_rank_list.extend(rank_to_check_list)
                    rank_to_check_list = []
                    for i in range(len(check_rank_list)):
                        check_rank = check_rank_list.pop()
                        for process_group in homogeneous_group:
                            if check_rank in process_group:
                                rank_to_append = (
                                    process_group[0] if process_group[1] == check_rank else process_group[1]
                                )
                                if rank_to_append not in ring_for_rank:
                                    stable_status = False
                                    rank_to_check_list.append(rank_to_append)
                                    ring_for_rank.append(rank_to_append)

                ring.append(ring_for_rank)
                ranks_in_ring.extend(ring_for_rank)

            return ring

        assert _power_of_two(self.world_size)
        # Fallback mesh: as square as possible, e.g. (2, 2) for 4 devices,
        # (2, 4) for 8 devices.
        power_of_two = int(math.log2(self.world_size))
        median = power_of_two // 2
        balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median))
        row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
        balanced_logical_mesh = []
        for row_index in range(row_size):
            balanced_logical_mesh.append([])
            for column_index in range(column_size):
                balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])

        homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
        # Sort betas descending so pop() yields the lowest remaining beta
        # (the fastest link class) first.
        beta_list = [b for b in homogeneous_device_dict.keys()]
        beta_list.sort()
        beta_list.reverse()
        homogeneous_types = len(beta_list)
        best_logical_mesh = None
        if homogeneous_types >= 2:
            for _ in range(homogeneous_types - 1):
                lowest_beta = beta_list.pop()
                best_homogeneous_group = homogeneous_device_dict[lowest_beta]
                # if the best homogeneous group contains all physical devices,
                # we will build the logical device mesh based on it. Otherwise,
                # we will check next level homogeneous group.
                if _check_contain_all_devices(best_homogeneous_group):
                    # We choose the largest ring for each rank to maximum the best bus utilization.
                    best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
                    break

        if homogeneous_types == 1 or best_logical_mesh is None:
            # in this case, we use balanced logical mesh as the best
            # logical mesh.
            best_logical_mesh = balanced_logical_mesh

        return best_logical_mesh

    def extract_alpha_beta_for_device_mesh(self):
        """
        Extract the mesh_alpha list and mesh_beta list based on the
        best logical mesh, which will be used to initialize the device mesh.

        Usage:
            >>> physical_devices = [0, 1, 2, 3]
            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
            >>> mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
            >>> print(mesh_alpha)
            [2.5917552411556242e-05, 0.00010312341153621673]
            >>> print(mesh_beta)
            [5.875573704655635e-11, 4.7361584445959614e-12]
        """
        best_logical_mesh = self.search_best_logical_mesh()

        # First axis: first device of each row; second axis: the first row.
        first_axis = [row[0] for row in best_logical_mesh]
        second_axis = best_logical_mesh[0]

        # init process group for both axes
        first_axis_process_group = dist.new_group(first_axis)
        second_axis_process_group = dist.new_group(second_axis)

        # extract alpha and beta for both axes
        def _extract_alpha_beta(pg, pg_handler):
            latency = self.profile_latency(pg, pg_handler)
            bandwidth = self.profile_bandwidth(pg, pg_handler)
            broadcast_object = [latency, bandwidth]
            dist.broadcast_object_list(broadcast_object, src=pg[0])
            return broadcast_object

        first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
        second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
        mesh_alpha = [first_latency, second_latency]
        # The beta values have been enlarged by 1e10 times temporarily because the computation cost
        # is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future.
        mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]

        return mesh_alpha, mesh_beta
|
|
|