mirror of https://github.com/hpcaitech/ColossalAI
parent
cc0ed7cf33
commit
6630d45546
@ -0,0 +1,4 @@
|
||||
from .calc_pipeline_strategy import alpa_dp
|
||||
from .profile_alpha_beta import profile_alpha_beta
|
||||
|
||||
__all__ = ['profile_alpha_beta', 'alpa_dp']
|
@ -0,0 +1,127 @@
|
||||
from math import pow
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
|
||||
submesh_choices = []
|
||||
i = 1
|
||||
p = -1
|
||||
while i <= num_devices_per_host:
|
||||
i *= 2
|
||||
p += 1
|
||||
assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
|
||||
f"while now num_devices_per_host = {num_devices_per_host}")
|
||||
if mode == "alpa":
|
||||
for i in range(p + 1):
|
||||
submesh_choices.append((1, pow(2, i)))
|
||||
for i in range(2, num_hosts + 1):
|
||||
submesh_choices.append((i, num_devices_per_host))
|
||||
elif mode == "new":
|
||||
for i in range(p // 2 + 1):
|
||||
for j in range(i, p - i + 1):
|
||||
submesh_choices.append((pow(2, i), pow(2, j)))
|
||||
return submesh_choices
|
||||
|
||||
|
||||
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
|
||||
best_configs):
|
||||
"""Implementation of Alpa DP for pipeline strategy
|
||||
Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
|
||||
|
||||
Arguments:
|
||||
num_layers: K
|
||||
num_devices: N*M
|
||||
num_microbatches: B
|
||||
submesh_choices: List[(n_i,m_i)]
|
||||
compute_cost: t_intra
|
||||
"""
|
||||
# For f, layer ID start from 0
|
||||
# f[#pipeline stages, layer id that is currently being considered, number of devices used]
|
||||
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
|
||||
f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32)
|
||||
f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32)
|
||||
f[0, num_layers, 0] = 0
|
||||
for s in range(1, num_layers + 1):
|
||||
for k in range(num_layers - 1, -1, -1):
|
||||
for d in range(1, num_devices + 1):
|
||||
for m, submesh in enumerate(submesh_choices):
|
||||
n_submesh_devices = np.prod(np.array(submesh))
|
||||
if n_submesh_devices <= d:
|
||||
# TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete.
|
||||
# if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]:
|
||||
# ...
|
||||
for i in range(num_layers, k, -1):
|
||||
stage_cost = compute_cost[k, i, m]
|
||||
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
|
||||
if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
|
||||
f[s, k, d] = new_cost
|
||||
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
|
||||
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
|
||||
best_s = -1
|
||||
best_total_cost = np.inf
|
||||
for s in range(1, num_layers + 1):
|
||||
if f[s, 0, num_devices] < best_total_cost:
|
||||
best_s = s
|
||||
best_total_cost = f[s, 0, num_devices]
|
||||
|
||||
if np.isinf(best_total_cost):
|
||||
return np.inf, None
|
||||
|
||||
total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices]
|
||||
current_s = best_s
|
||||
current_layer = 0
|
||||
current_devices = num_devices
|
||||
|
||||
res = []
|
||||
while current_s > 0 and current_layer < num_layers and current_devices > 0:
|
||||
next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
|
||||
assert next_start_layer != -1 and current_devices != -1
|
||||
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
|
||||
current_s -= 1
|
||||
current_layer = next_start_layer
|
||||
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
|
||||
assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
|
||||
|
||||
return total_cost, res
|
||||
|
||||
|
||||
def alpa_dp(num_layers,
|
||||
num_devices,
|
||||
num_microbatches,
|
||||
submesh_choices,
|
||||
num_autosharding_configs,
|
||||
compute_cost,
|
||||
gap=1e-6):
|
||||
"""Alpa auto stage dynamic programming.
|
||||
Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
|
||||
|
||||
Arguments:
|
||||
submesh_choices: List[(int,int)]
|
||||
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
|
||||
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
|
||||
"""
|
||||
assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
|
||||
num_autosharding_configs), "Cost shape wrong."
|
||||
all_possible_stage_costs = np.sort(np.unique(compute_cost))
|
||||
best_cost = np.inf
|
||||
best_solution = None
|
||||
last_max_stage_cost = 0.0
|
||||
# TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost?
|
||||
# In dp_impl it seems the argmin n_config will be chosen. Just amin here.
|
||||
best_configs = np.argmin(compute_cost, axis=3)
|
||||
best_compute_cost = np.amin(compute_cost, axis=3)
|
||||
assert len(all_possible_stage_costs), "no solution in auto stage construction."
|
||||
for max_stage_cost in all_possible_stage_costs:
|
||||
if max_stage_cost * num_microbatches >= best_cost:
|
||||
break
|
||||
if max_stage_cost - last_max_stage_cost < gap:
|
||||
continue
|
||||
cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
|
||||
max_stage_cost, best_configs)
|
||||
if cost < best_cost:
|
||||
best_cost = cost
|
||||
best_solution = solution
|
||||
last_max_stage_cost = max_stage_cost
|
||||
|
||||
return best_cost, best_solution
|
@ -0,0 +1,120 @@
|
||||
import fcntl
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
MB = int((1 << 10) * 1e3)
|
||||
GB = int((1 << 20) * 1e3)
|
||||
Byte = 4
|
||||
FRAMEWORK = 0
|
||||
NON_SENSE = (0.1, 0.1)
|
||||
|
||||
|
||||
def printflock(*msgs):
|
||||
""" solves multi-process interleaved print problem """
|
||||
with open(__file__, "r") as fh:
|
||||
fcntl.flock(fh, fcntl.LOCK_EX)
|
||||
try:
|
||||
print(*msgs)
|
||||
finally:
|
||||
fcntl.flock(fh, fcntl.LOCK_UN)
|
||||
|
||||
|
||||
def profile(device1d, nbytes, ctype):
|
||||
warmup = 5
|
||||
repeat = 25
|
||||
rank = dist.get_rank()
|
||||
src_device_num = device1d[0]
|
||||
wsize = len(device1d)
|
||||
group = dist.new_group(device1d)
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device("cuda", rank)
|
||||
buf = torch.randn(nbytes // 4).to(device)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# warmup
|
||||
for _ in range(warmup):
|
||||
if ctype == "a":
|
||||
dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
|
||||
elif ctype == "b":
|
||||
dist.broadcast(buf, src=src_device_num, group=group)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
dist.barrier()
|
||||
begin = time.perf_counter()
|
||||
for _ in range(repeat):
|
||||
if ctype == "a":
|
||||
dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=group)
|
||||
elif ctype == "b":
|
||||
dist.broadcast(buf, src=src_device_num, group=group)
|
||||
torch.cuda.synchronize()
|
||||
end = time.perf_counter()
|
||||
dist.barrier()
|
||||
|
||||
if rank == src_device_num:
|
||||
avg_time_s = (end - begin) / repeat - FRAMEWORK
|
||||
alg_band = nbytes / avg_time_s
|
||||
if ctype == "b":
|
||||
bus_band = alg_band
|
||||
elif ctype == "a":
|
||||
bus_band = 2 * (wsize - 1) / wsize * alg_band
|
||||
print(
|
||||
f"GPU:{rank}, Bytes: {nbytes} B,Time: {round(avg_time_s * 1e6,2)} us, Bus bandwidth: {round(bus_band / GB,2)} GB/s"
|
||||
)
|
||||
return (avg_time_s, alg_band)
|
||||
else:
|
||||
return NON_SENSE # Just a placeholder
|
||||
|
||||
|
||||
def profile_latency(device1d, it=3, ctype="a"):
|
||||
latency = []
|
||||
for i in range(it):
|
||||
nbytes = int(Byte << i)
|
||||
(t, _) = profile(device1d, nbytes, ctype)
|
||||
latency.append(t)
|
||||
return min(latency)
|
||||
|
||||
|
||||
def profile_bandwidth(device1d, maxbytes, ctype="a"):
|
||||
(_, bandwidth) = profile(device1d, maxbytes, ctype)
|
||||
return bandwidth
|
||||
|
||||
|
||||
def profile_ab(rank, *args):
|
||||
wsize = int(torch.cuda.device_count())
|
||||
device1d = args[0]
|
||||
return_dict = args[1]
|
||||
ctype = args[2]
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '29020'
|
||||
dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://', world_size=wsize, rank=rank)
|
||||
|
||||
device = torch.device("cuda", rank)
|
||||
max_nbytes = torch.tensor(torch.cuda.mem_get_info(device)[0]).to(device)
|
||||
max_nbytes = min(int(4 * GB), int(GB << int(math.log2(max_nbytes.item() / GB))))
|
||||
|
||||
if rank == device1d[0]:
|
||||
print(f"max_nbytes: {max_nbytes} B")
|
||||
|
||||
alpha = profile_latency(device1d, it=5, ctype=ctype)
|
||||
beta = 1 / profile_bandwidth(device1d, maxbytes=max_nbytes, ctype=ctype)
|
||||
|
||||
if rank == device1d[0]:
|
||||
print(f"alpha(us): {round(alpha * 1e6,2)}, beta(us/GB): {round(beta * 1e6 * GB,2)}")
|
||||
return_dict[rank] = (alpha, beta)
|
||||
|
||||
|
||||
def profile_alpha_beta(device1d):
|
||||
assert torch.cuda.is_available()
|
||||
assert len(device1d) > 0 and len(device1d) <= int(torch.cuda.device_count())
|
||||
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
ctype = "a"
|
||||
mp.spawn(profile_ab, args=[device1d, return_dict, ctype], nprocs=int(torch.cuda.device_count()))
|
||||
return return_dict[device1d[0]]
|
@ -0,0 +1,14 @@
|
||||
import pytest
|
||||
|
||||
from colossalai.device import profile_alpha_beta
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skip because assertion fails for CI devices")
|
||||
def test_profile_alpha_beta():
|
||||
physical_devices = [0, 1, 2, 3]
|
||||
(alpha, beta) = profile_alpha_beta(physical_devices)
|
||||
assert alpha > 0 and alpha < 1e-4 and beta > 0 and beta < 1e-10
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_profile_alpha_beta()
|
Loading…
Reference in new issue