[autoparallel] Add alpha beta (#1973)

* Add alpha beta

* Fix test

* Fix test
Genghan Zhang 2 years ago committed by GitHub
parent cc0ed7cf33
commit 6630d45546

@@ -0,0 +1,4 @@
from .calc_pipeline_strategy import alpa_dp
from .profile_alpha_beta import profile_alpha_beta
__all__ = ['profile_alpha_beta', 'alpa_dp']

@@ -0,0 +1,127 @@
import numpy as np
def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
    """Enumerate the candidate (#hosts, #devices per host) submesh shapes."""
    submesh_choices = []
    i = 1
    p = -1
    while i <= num_devices_per_host:
        i *= 2
        p += 1
    assert 2**p == num_devices_per_host, ("Only supports the cases where num_devices_per_host is a power of two, "
                                          f"while now num_devices_per_host = {num_devices_per_host}")
    if mode == "alpa":
        # alpa's original choices: single-host meshes (1, 1), (1, 2), ..., (1, num_devices_per_host),
        # then whole-host meshes (2, num_devices_per_host), ..., (num_hosts, num_devices_per_host).
        for i in range(p + 1):
            submesh_choices.append((1, 2**i))
        for i in range(2, num_hosts + 1):
            submesh_choices.append((i, num_devices_per_host))
    elif mode == "new":
        # Enumerate all 2D meshes (2**i, 2**j) with i <= j and 2**(i + j) <= num_devices_per_host.
        for i in range(p // 2 + 1):
            for j in range(i, p - i + 1):
                submesh_choices.append((2**i, 2**j))
    return submesh_choices
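For orientation, here is what the enumeration works out to for a hypothetical cluster of 2 hosts with 8 GPUs each (values computed by hand from the code above; note that mode "new" depends only on num_devices_per_host):

# Hypothetical cluster: 2 hosts, 8 devices per host, so p = 3.
print(get_submesh_choices(2, 8, mode="alpa"))
# [(1, 1), (1, 2), (1, 4), (1, 8), (2, 8)]
print(get_submesh_choices(2, 8, mode="new"))
# [(1, 1), (1, 2), (1, 4), (1, 8), (2, 2), (2, 4)]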
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
                 best_configs):
    """Implementation of the Alpa DP algorithm for the pipeline strategy.
    Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
    Arguments:
        num_layers: K
        num_devices: N*M
        num_microbatches: B
        submesh_choices: List[(n_i, m_i)]
        compute_cost: t_intra, already reduced over the autosharding dimension
        max_stage_cost: upper bound on the cost of any single stage in this DP pass
        best_configs: index of the best autosharding config for each (start_layer, end_layer, submesh) entry
    """
    # For f, layer IDs start from 0.
    # f[# pipeline stages, layer id that is currently being considered, number of devices used]
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32)
f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32)
f[0, num_layers, 0] = 0
for s in range(1, num_layers + 1):
for k in range(num_layers - 1, -1, -1):
for d in range(1, num_devices + 1):
for m, submesh in enumerate(submesh_choices):
n_submesh_devices = np.prod(np.array(submesh))
if n_submesh_devices <= d:
# TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete.
# if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]:
# ...
                        for i in range(num_layers, k, -1):
                            # Candidate stage: layers k .. i - 1 placed on submesh m.
                            # compute_cost / best_configs are indexed by the inclusive layer range [k, i - 1].
                            stage_cost = compute_cost[k, i - 1, m]
                            # The remaining layers start at i and use the devices left after this stage.
                            new_cost = f[s - 1, i, d - n_submesh_devices] + stage_cost
                            if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
                                f[s, k, d] = new_cost
                                f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
                                f_argmin[s, k, d] = (i, m, best_configs[k, i - 1, m])
best_s = -1
best_total_cost = np.inf
for s in range(1, num_layers + 1):
if f[s, 0, num_devices] < best_total_cost:
best_s = s
best_total_cost = f[s, 0, num_devices]
if np.isinf(best_total_cost):
return np.inf, None
total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices]
current_s = best_s
current_layer = 0
current_devices = num_devices
res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0:
next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1
current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
return total_cost, res
def alpa_dp(num_layers,
            num_devices,
            num_microbatches,
            submesh_choices,
            num_autosharding_configs,
            compute_cost,
            gap=1e-6):
    """Alpa auto stage dynamic programming.
    Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
    Arguments:
        num_layers: K, number of layers
        num_devices: N*M, total number of devices
        num_microbatches: B, number of microbatches
        submesh_choices: List[(int, int)], candidate submesh shapes
        num_autosharding_configs: maximum number of autosharding configs, i.e. the number of
            t_intra(start_layer, end_layer, LogicalMesh) candidates per layer range and submesh
        compute_cost: np.array of shape (num_layers, num_layers, num_submesh_choices, num_autosharding_configs);
            entry [i, j, m, c] is the cost of layers i..j on submesh_choices[m] with autosharding config c
        gap: minimum increase of max_stage_cost between two DP passes that is worth re-solving for
    """
    assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
                                      num_autosharding_configs), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf
best_solution = None
last_max_stage_cost = 0.0
# TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost?
# In dp_impl it seems the argmin n_config will be chosen. Just amin here.
best_configs = np.argmin(compute_cost, axis=3)
best_compute_cost = np.amin(compute_cost, axis=3)
assert len(all_possible_stage_costs), "no solution in auto stage construction."
for max_stage_cost in all_possible_stage_costs:
if max_stage_cost * num_microbatches >= best_cost:
break
if max_stage_cost - last_max_stage_cost < gap:
continue
cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
max_stage_cost, best_configs)
if cost < best_cost:
best_cost = cost
best_solution = solution
last_max_stage_cost = max_stage_cost
return best_cost, best_solution
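To make the expected call shape concrete, here is a hedged toy invocation; the sizes and the random cost tensor are placeholders, not a recommended configuration (a real compute_cost would come from profiling each layer range on each submesh and autosharding config):

import numpy as np

# Toy problem: 4 layers, 4 devices on a single host, 8 microbatches.
submesh_choices = get_submesh_choices(num_hosts=1, num_devices_per_host=4, mode="new")
num_autosharding_configs = 2
compute_cost = np.random.rand(4, 4, len(submesh_choices), num_autosharding_configs)
best_cost, best_solution = alpa_dp(4, 4, 8, submesh_choices, num_autosharding_configs, compute_cost)
# best_solution is a list of ((start_layer, next_start_layer), submesh_choice_idx, autosharding_config_idx)
# tuples, one per pipeline stage, read off from f_argmin in alpa_dp_impl.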

@@ -0,0 +1,120 @@
import fcntl
import math
import os
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
MB = int((1 << 10) * 1e3)    # ~1e6 bytes
GB = int((1 << 20) * 1e3)    # ~1e9 bytes
Byte = 4                     # bytes per fp32 element; the smallest message size profiled
FRAMEWORK = 0                # framework overhead (s) subtracted from the measured time
NON_SENSE = (0.1, 0.1)       # placeholder returned by ranks that do not report a measurement
def printflock(*msgs):
""" solves multi-process interleaved print problem """
with open(__file__, "r") as fh:
fcntl.flock(fh, fcntl.LOCK_EX)
try:
print(*msgs)
finally:
fcntl.flock(fh, fcntl.LOCK_UN)
def profile(device1d, nbytes, ctype):
warmup = 5
repeat = 25
rank = dist.get_rank()
src_device_num = device1d[0]
wsize = len(device1d)
group = dist.new_group(device1d)
torch.cuda.set_device(rank)
device = torch.device("cuda", rank)
buf = torch.randn(nbytes // 4).to(device)
torch.cuda.synchronize()
# warmup
for _ in range(warmup):
if ctype == "a":
dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
elif ctype == "b":
dist.broadcast(buf, src=src_device_num, group=group)
torch.cuda.synchronize()
dist.barrier()
begin = time.perf_counter()
for _ in range(repeat):
if ctype == "a":
dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
elif ctype == "b":
dist.broadcast(buf, src=src_device_num, group=group)
torch.cuda.synchronize()
end = time.perf_counter()
dist.barrier()
if rank == src_device_num:
        avg_time_s = (end - begin) / repeat - FRAMEWORK
        alg_band = nbytes / avg_time_s
        if ctype == "b":
            # broadcast: bus bandwidth equals algorithmic bandwidth
            bus_band = alg_band
        elif ctype == "a":
            # ring all-reduce moves 2 * (wsize - 1) / wsize bytes per byte of payload,
            # so bus bandwidth = 2 * (wsize - 1) / wsize * algorithmic bandwidth
            bus_band = 2 * (wsize - 1) / wsize * alg_band
        print(
            f"GPU: {rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, "
            f"Bus bandwidth: {round(bus_band / GB, 2)} GB/s")
return (avg_time_s, alg_band)
else:
return NON_SENSE # Just a placeholder
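The measurement above targets the classic alpha-beta communication model: sending n bytes over the profiled group is predicted to cost roughly T(n) = alpha + beta * n seconds, with alpha the latency and beta the inverse bandwidth. A minimal sketch of how the profiled pair might be consumed (the helper name and the numeric values are illustrative assumptions, not part of this PR):

def predict_comm_time(alpha, beta, nbytes):
    # alpha-beta model: fixed startup latency plus a per-byte transfer cost
    return alpha + beta * nbytes

# e.g. a 64 MB all-reduce with alpha = 20 us and beta = 5e-11 s/B (~20 GB/s)
print(predict_comm_time(2e-5, 5e-11, 64 * MB))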
def profile_latency(device1d, it=3, ctype="a"):
latency = []
for i in range(it):
nbytes = int(Byte << i)
(t, _) = profile(device1d, nbytes, ctype)
latency.append(t)
return min(latency)
def profile_bandwidth(device1d, maxbytes, ctype="a"):
(_, bandwidth) = profile(device1d, maxbytes, ctype)
return bandwidth
def profile_ab(rank, *args):
wsize = int(torch.cuda.device_count())
device1d = args[0]
return_dict = args[1]
ctype = args[2]
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29020'
dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank)
device = torch.device("cuda", rank)
max_nbytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device)
max_nbytes = min(int(4 * GB), int(GB << int(math.log2(max_nbytes.item() / GB))))
if rank == device1d[0]:
print(f"max_nbytes: {max_nbytes} B")
alpha = profile_latency(device1d, it=5, ctype=ctype)
beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype)
if rank == device1d[0]:
print(f"alpha(us): {round(alpha * 1e6,2)}, beta(us/GB): {round(beta * 1e6 * GB,2)}")
return_dict[rank] = (alpha, beta)
def profile_alpha_beta(device1d):
assert torch.cuda.is_available()
assert len(device1d) > 0 and len(device1d) <= int(torch.cuda.device_count())
manager = mp.Manager()
return_dict = manager.dict()
ctype = "a"
mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=int(torch.cuda.device_count()))
return return_dict[device1d[0]]
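For reference, a hedged usage sketch on a hypothetical single node with four visible GPUs (the magnitudes in the comments are typical orders, not measured values):

# Spawns one process per visible GPU and profiles all-reduce over GPUs 0-3.
alpha, beta = profile_alpha_beta([0, 1, 2, 3])
# alpha: startup latency in seconds (typically tens of microseconds)
# beta:  inverse of the measured algorithmic bandwidth, in seconds per byte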

@@ -0,0 +1,14 @@
import pytest
from colossalai.device import profile_alpha_beta
@pytest.mark.skip(reason="Skip because assertion fails for CI devices")
def test_profile_alpha_beta():
physical_devices = [0, 1, 2, 3]
(alpha, beta) = profile_alpha_beta(physical_devices)
assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
if __name__ == '__main__':
test_profile_alpha_beta()