diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py
index e69de29bb..879b60c06 100644
--- a/colossalai/device/__init__.py
+++ b/colossalai/device/__init__.py
@@ -0,0 +1,4 @@
+from .calc_pipeline_strategy import alpa_dp
+from .profile_alpha_beta import profile_alpha_beta
+
+__all__ = ['profile_alpha_beta', 'alpa_dp']
diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py
new file mode 100644
index 000000000..4ab72dfe6
--- /dev/null
+++ b/colossalai/device/calc_pipeline_strategy.py
@@ -0,0 +1,127 @@
+import numpy as np
+
+
+def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
+    submesh_choices = []
+    i = 1
+    p = -1
+    while i <= num_devices_per_host:
+        i *= 2
+        p += 1
+    assert 2**p == num_devices_per_host, ("Only supports the cases where num_devices_per_host is a power of two, "
+                                          f"while now num_devices_per_host = {num_devices_per_host}")
+    if mode == "alpa":
+        for i in range(p + 1):
+            submesh_choices.append((1, 2**i))
+        for i in range(2, num_hosts + 1):
+            submesh_choices.append((i, num_devices_per_host))
+    elif mode == "new":
+        for i in range(p // 2 + 1):
+            for j in range(i, p - i + 1):
+                submesh_choices.append((2**i, 2**j))
+    return submesh_choices
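+
+
+# For illustration: with one host of four devices the candidate submeshes are
+#   >>> get_submesh_choices(num_hosts=1, num_devices_per_host=4, mode="new")
+#   [(1, 1), (1, 2), (1, 4), (2, 2)]
+#   >>> get_submesh_choices(num_hosts=1, num_devices_per_host=4, mode="alpa")
+#   [(1, 1), (1, 2), (1, 4)]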
+
+
+def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
+                 best_configs):
+    """Implementation of Alpa DP for pipeline strategy
+    Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
+
+    Arguments:
+        num_layers: K
+        num_devices: N*M
+        num_microbatches: B
+        submesh_choices: List[(n_i, m_i)]
+        compute_cost: t_intra
+    """
+    # For f, layer IDs start from 0
+    # f[#pipeline stages, layer id that is currently being considered, number of devices used]
+    f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
+    f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32)
+    f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32)
+    f[0, num_layers, 0] = 0
+    for s in range(1, num_layers + 1):
+        for k in range(num_layers - 1, -1, -1):
+            for d in range(1, num_devices + 1):
+                for m, submesh in enumerate(submesh_choices):
+                    n_submesh_devices = np.prod(np.array(submesh))
+                    if n_submesh_devices <= d:
+                        # TODO: [luzgh]: Why does alpa need max_n_succ_stages? Delete.
+                        # if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]:
+                        # ...
+                        for i in range(num_layers, k, -1):
+                            # A stage holding layers k..i-1 runs on submesh_choices[m];
+                            # the remaining layers start from i on the remaining devices.
+                            stage_cost = compute_cost[k, i - 1, m]
+                            new_cost = f[s - 1, i, d - n_submesh_devices] + stage_cost
+                            if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
+                                f[s, k, d] = new_cost
+                                f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
+                                f_argmin[s, k, d] = (i, m, best_configs[k, i - 1, m])
+    best_s = -1
+    best_total_cost = np.inf
+    for s in range(1, num_layers + 1):
+        if f[s, 0, num_devices] < best_total_cost:
+            best_s = s
+            best_total_cost = f[s, 0, num_devices]
+
+    if np.isinf(best_total_cost):
+        return np.inf, None
+
+    # Pipeline latency: fill/drain cost plus (B - 1) steps bounded by the slowest stage.
+    total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices]
+    current_s = best_s
+    current_layer = 0
+    current_devices = num_devices
+
+    res = []
+    while current_s > 0 and current_layer < num_layers and current_devices > 0:
+        next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
+        assert next_start_layer != -1 and current_devices != -1
+        res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
+        current_s -= 1
+        current_layer = next_start_layer
+        current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
+    assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
+
+    return total_cost, res
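+
+
+# For illustration: alpa_dp_impl returns (total_cost, res), where res lists one
+# ((start_layer, next_start_layer), submesh_choice_index, autosharding_config_index)
+# tuple per pipeline stage, ordered from the first stage; e.g.
+#   [((0, 4), 2, 0), ((4, 8), 1, 3)]
+# would put layers 0-3 on submesh_choices[2] with autosharding config 0 and
+# layers 4-7 on submesh_choices[1] with autosharding config 3.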
+
+
+def alpa_dp(num_layers,
+            num_devices,
+            num_microbatches,
+            submesh_choices,
+            num_autosharding_configs,
+            compute_cost,
+            gap=1e-6):
+    """Alpa auto stage dynamic programming.
+    Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
+
+    Arguments:
+        submesh_choices: List[(int, int)]
+        num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
+        compute_cost: np.array(num_layers, num_layers, num_submesh_choices, num_autosharding_configs)
+    """
+    assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
+                                      num_autosharding_configs), "Cost shape wrong."
+    all_possible_stage_costs = np.sort(np.unique(compute_cost))
+    best_cost = np.inf
+    best_solution = None
+    last_max_stage_cost = 0.0
+    # TODO: [luzgh]: Why does alpa need the num_autosharding_configs dimension in compute_cost?
+    # In dp_impl it seems the argmin n_config will be chosen, so just take amin here.
+    best_configs = np.argmin(compute_cost, axis=3)
+    best_compute_cost = np.amin(compute_cost, axis=3)
+    assert len(all_possible_stage_costs), "No solution in auto stage construction."
+    for max_stage_cost in all_possible_stage_costs:
+        if max_stage_cost * num_microbatches >= best_cost:
+            break
+        if max_stage_cost - last_max_stage_cost < gap:
+            continue
+        cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
+                                      max_stage_cost, best_configs)
+        if cost < best_cost:
+            best_cost = cost
+            best_solution = solution
+        last_max_stage_cost = max_stage_cost
+
+    return best_cost, best_solution
diff --git a/colossalai/device/profile_alpha_beta.py b/colossalai/device/profile_alpha_beta.py
new file mode 100644
index 000000000..2d053ddbe
--- /dev/null
+++ b/colossalai/device/profile_alpha_beta.py
@@ -0,0 +1,120 @@
+import fcntl
+import math
+import os
+import time
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+MB = int((1 << 10) * 1e3)
+GB = int((1 << 20) * 1e3)
+Byte = 4
+FRAMEWORK = 0
+NON_SENSE = (0.1, 0.1)
+
+
+def printflock(*msgs):
+    """Solves the multi-process interleaved print problem."""
+    with open(__file__, "r") as fh:
+        fcntl.flock(fh, fcntl.LOCK_EX)
+        try:
+            print(*msgs)
+        finally:
+            fcntl.flock(fh, fcntl.LOCK_UN)
+
+
+def profile(device1d, nbytes, ctype):
+    warmup = 5
+    repeat = 25
+    rank = dist.get_rank()
+    src_device_num = device1d[0]
+    wsize = len(device1d)
+    group = dist.new_group(device1d)
+
+    torch.cuda.set_device(rank)
+    device = torch.device("cuda", rank)
+    buf = torch.randn(nbytes // 4).to(device)
+
+    torch.cuda.synchronize()
+    # warmup
+    for _ in range(warmup):
+        if ctype == "a":
+            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
+        elif ctype == "b":
+            dist.broadcast(buf, src=src_device_num, group=group)
+    torch.cuda.synchronize()
+
+    dist.barrier()
+    begin = time.perf_counter()
+    for _ in range(repeat):
+        if ctype == "a":
+            dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
+        elif ctype == "b":
+            dist.broadcast(buf, src=src_device_num, group=group)
+    torch.cuda.synchronize()
+    end = time.perf_counter()
+    dist.barrier()
+
+    if rank == src_device_num:
+        avg_time_s = (end - begin) / repeat - FRAMEWORK
+        alg_band = nbytes / avg_time_s
+        if ctype == "b":
+            bus_band = alg_band
+        elif ctype == "a":
+            bus_band = 2 * (wsize - 1) / wsize * alg_band
+        print(
+            f"GPU:{rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, Bus bandwidth: {round(bus_band / GB, 2)} GB/s"
+        )
+        return (avg_time_s, alg_band)
+    else:
+        return NON_SENSE    # Just a placeholder
+
+
+def profile_latency(device1d, it=3, ctype="a"):
+    latency = []
+    for i in range(it):
+        nbytes = int(Byte << i)
+        (t, _) = profile(device1d, nbytes, ctype)
+        latency.append(t)
+    return min(latency)
+
+
+def profile_bandwidth(device1d, maxbytes, ctype="a"):
+    (_, bandwidth) = profile(device1d, maxbytes, ctype)
+    return bandwidth
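+
+
+# For illustration: these measurements feed the usual alpha-beta cost model, which
+# estimates the time to communicate n bytes as
+#   t(n) ~= alpha + beta * n
+# where alpha (seconds) is the latency measured by profile_latency on tiny messages and
+# beta (seconds per byte) is the inverse of the bandwidth measured by profile_bandwidth.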
+
+
+def profile_ab(rank, *args):
+    wsize = int(torch.cuda.device_count())
+    device1d = args[0]
+    return_dict = args[1]
+    ctype = args[2]
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '29020'
+    dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank)
+
+    device = torch.device("cuda", rank)
+    max_nbytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device)
+    max_nbytes = min(int(4 * GB), int(GB << int(math.log2(max_nbytes.item() / GB))))
+
+    if rank == device1d[0]:
+        print(f"max_nbytes: {max_nbytes} B")
+
+    alpha = profile_latency(device1d, it=5, ctype=ctype)
+    beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype)
+
+    if rank == device1d[0]:
+        print(f"alpha(us): {round(alpha * 1e6, 2)}, beta(us/GB): {round(beta * 1e6 * GB, 2)}")
+    return_dict[rank] = (alpha, beta)
+
+
+def profile_alpha_beta(device1d):
+    assert torch.cuda.is_available()
+    assert len(device1d) > 0 and len(device1d) <= int(torch.cuda.device_count())
+
+    manager = mp.Manager()
+    return_dict = manager.dict()
+    ctype = "a"
+    mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=int(torch.cuda.device_count()))
+    return return_dict[device1d[0]]
diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py
new file mode 100644
index 000000000..5b076fdf0
--- /dev/null
+++ b/tests/test_device/test_alpha_beta.py
@@ -0,0 +1,14 @@
+import pytest
+
+from colossalai.device import profile_alpha_beta
+
+
+@pytest.mark.skip(reason="Skip because assertion fails for CI devices")
+def test_profile_alpha_beta():
+    physical_devices = [0, 1, 2, 3]
+    (alpha, beta) = profile_alpha_beta(physical_devices)
+    assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
+
+
+if __name__ == '__main__':
+    test_profile_alpha_beta()
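
A minimal usage sketch of the two entry points exposed by this patch, profile_alpha_beta and alpa_dp. The device list, the synthetic compute_cost, and the expected result in the trailing comment are illustrative assumptions, not values produced by the patch itself.

import numpy as np
import torch

from colossalai.device import alpa_dp, profile_alpha_beta

# Alpha-beta profiling spawns one process per visible GPU, so only attempt it on a CUDA node.
if torch.cuda.is_available() and torch.cuda.device_count() >= 4:
    alpha, beta = profile_alpha_beta([0, 1, 2, 3])
    print(f"alpha = {alpha:.2e} s, beta = {beta:.2e} s/B")

# Stage construction on a synthetic cost table: 4 layers, 4 devices, 1 autosharding config,
# and every candidate stage costing 1.0 (purely hypothetical numbers).
num_layers, num_devices, num_configs = 4, 4, 1
submesh_choices = [(1, 1), (1, 2), (1, 4)]
compute_cost = np.full((num_layers, num_layers, len(submesh_choices), num_configs), 1.0)
best_cost, solution = alpa_dp(num_layers, num_devices, num_microbatches=8,
                              submesh_choices=submesh_choices,
                              num_autosharding_configs=num_configs,
                              compute_cost=compute_cost)
# With this flat cost table a single stage on the (1, 4) submesh wins:
# best_cost == 8.0 and solution == [((0, 4), 2, 0)].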