mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] Add Simplified Alpa DP Partition (#2507)
* add alpa dp split
* add alpa dp split
* use fwd+bwd instead of fwd only

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
parent b42d3d28ed
commit 400f63012e
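For orientation, here is a minimal sketch of how the pass added by this commit might be driven, mirroring the get_annotated_model changes in the example script further down: trace the model with ColoTracer, run MetaInfoProp so every fx node carries fwd_flop and bwd_flop metadata, then let gpipe_dp_split_pass insert pipe_split markers. The toy module, input shape, and the pp_size / num_microbatches values are illustrative assumptions, not part of the commit.

import torch

from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import gpipe_dp_split_pass, split_with_split_nodes_pass
from colossalai.fx.passes.meta_info_prop import MetaInfoProp


class ToyModel(torch.nn.Module):
    # Hypothetical stand-in for the GPT model used in the example script.
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(64, 64)
        self.fc2 = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = ToyModel()
data_kwargs = {'x': torch.rand(8, 64)}

# Trace on meta tensors, then annotate every node with fwd_flop / bwd_flop.
tracer = ColoTracer()
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
MetaInfoProp(gm).run(*[v.to('meta') for v in data_kwargs.values()])

# Insert pipe_split nodes via the Alpa-style DP partition, then split into stages.
annotated = gpipe_dp_split_pass(gm, pp_size=2, num_microbatches=4, mode='block', block_limit=0.01)
top_module, split_submodules = split_with_split_nodes_pass(annotated, merge_output=True)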
@@ -1,4 +1,6 @@
+import numpy as np
 import torch
+import tqdm
 from torch.fx import symbolic_trace
 from torch.fx.node import Node

@@ -9,6 +11,165 @@ def pipe_split():
     pass


+def block_split():
+    pass
+
+
+# Construct blocks with the condition that (block_flops / total_flops) >= limit.
+def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
+    total_fwd_flop = 0
+    total_bwd_flop = 0
+    for node in gm.graph.nodes:
+        total_fwd_flop += node.fwd_flop
+        total_bwd_flop += node.bwd_flop
+
+    total_flop = total_fwd_flop + total_bwd_flop
+    per_block_flop = total_flop * limit
+    accumulate_fwd_flop = 0
+    accumulate_bwd_flop = 0
+    block_nodes = []
+    for node in gm.graph.nodes:
+        if 'block_split' in node.name:
+            continue
+        accumulate_fwd_flop += node.fwd_flop
+        accumulate_bwd_flop += node.bwd_flop
+        if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
+            with gm.graph.inserting_after(node):
+                block_node = gm.graph.create_node('call_function', block_split)
+                setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
+                setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
+            accumulate_fwd_flop = 0
+            accumulate_bwd_flop = 0
+            block_nodes.append(block_node)
+
+    return block_nodes
+
+
+def remove_blocks(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        if (node.op, node.target) == ('call_function', block_split):
+            gm.graph.erase_node(node)
+
+
+def get_compute_costs(node_list):
+    num_nodes = len(node_list)
+    all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
+
+    for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
+        for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
+            selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
+            all_compute_cost[start, end] = sum(selected_flops)
+
+    return all_compute_cost
+
+
+def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_costs, max_compute_cost):
+    """The core implementation of the DP algorithm."""
+    # Adapted from the Alpa DP formulation.
+    # For f, node IDs start from 0:
+    # f[number of stages, node id that is currently being considered]
+
+    # record time cost (assessed by fwd+bwd flops for now)
+    f = np.full((num_stages + 1, num_nodes + 1), np.inf, dtype=np.float32)
+
+    # record max stage compute cost among all stages in this partition.
+    f_stage_max = np.full((num_stages + 1, num_nodes + 1), 0.0, dtype=np.float32)
+    # record start node index for next stage in this partition
+    f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
+    f[0, num_nodes] = 0
+    for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False):    # pylint: disable=too-many-nested-blocks
+        for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
+            for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
+                stage_cost = compute_costs[i, k - 1]
+                new_cost = f[s - 1, k] + stage_cost
+                if stage_cost <= max_compute_cost and new_cost < f[s, i]:
+                    f[s, i] = new_cost
+                    f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
+                    f_argmin[s, i] = k
+
+    best_total_cost = f[num_stages, 0]
+    if np.isinf(best_total_cost):
+        return np.inf, None
+
+    total_cost = f[num_stages, 0] + (num_microbatches - 1) * f_stage_max[num_stages, 0]
+
+    current_s = num_stages
+    current_node = 0
+
+    res = []
+    while current_s > 0 and current_node < num_nodes:
+        next_start_node = f_argmin[current_s, current_node]
+        res.append((current_node, next_start_node))
+        current_s -= 1
+        current_node = next_start_node
+
+    return total_cost, res
+
+
+def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatches: int):
+    # Ignore the memory cost profiling in Alpa's design for convenience.
+    max_compute_costs = np.sort(np.unique(compute_costs))
+    best_cost = np.inf
+    best_solution = None
+    last_max_compute_cost = 0.0
+    gap = 1e6    # temporary magic number, unit: flops
+
+    for max_compute_cost in tqdm.tqdm(max_compute_costs):
+        # Pruning to reduce the search space.
+        if max_compute_cost * num_microbatches >= best_cost:
+            break
+        if max_compute_cost - last_max_compute_cost < gap:
+            continue
+
+        cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
+                                                max_compute_cost)
+
+        if cost < best_cost:
+            best_cost = cost
+            best_solution = solution
+        last_max_compute_cost = max_compute_cost
+    return best_cost, best_solution
+
+
+# Auto DP partition based on Alpa, adapted to the GPipe scheduler.
+# mode:
+#   'node': each fx node is a partition unit
+#   'block': a group of fx nodes forms a block
+def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
+    assert mode in ['node', 'block']
+
+    # nodes or blocks to be used in the partition.
+    node_list = []
+    if mode == 'node':
+        for node in gm.graph.nodes:
+            node_list.append(node)
+    elif mode == 'block':
+        node_list = construct_blocks(gm, limit=block_limit)
+    else:
+        pass
+
+    compute_costs = get_compute_costs(node_list)

+    best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
+
+    for (_, next_start_node) in best_solution:
+        if pp_size <= 1:
+            break
+        node = node_list[next_start_node]
+        with gm.graph.inserting_before(node):
+            split_node = gm.graph.create_node('call_function', pipe_split)
+        pp_size -= 1
+
+    # remove block nodes if present
+    if mode == 'block':
+        remove_blocks(gm)
+
+    gm.recompile()
+    return gm
+
+
 def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     """
     In avgcompute_split_pass, we split module by the fwd flops.
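The DP above optimizes the GPipe latency model that appears in do_dp_split_gpipe_impl: for a partition with per-stage costs c_s and m microbatches, the objective is sum(c_s) + (m - 1) * max(c_s), since every microbatch traverses every stage once and the slowest stage paces the steady-state pipeline. A toy illustration with made-up costs (not taken from any real profile):

# Toy numbers only; in the pass the stage costs come from compute_costs[i, k - 1].
stage_costs = [4.0, 3.0, 5.0, 4.0]    # hypothetical fwd+bwd FLOP cost per stage
num_microbatches = 8

total_cost = sum(stage_costs) + (num_microbatches - 1) * max(stage_costs)
print(total_cost)    # 16.0 + 7 * 5.0 = 51.0

# A more balanced split with the same total work lowers the bottleneck term,
# which is why the DP tracks f_stage_max alongside f.
balanced = [4.0, 4.0, 4.0, 4.0]
print(sum(balanced) + (num_microbatches - 1) * max(balanced))    # 16.0 + 7 * 4.0 = 44.0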
@@ -114,6 +114,7 @@ class MetaInfoProp(torch.fx.Interpreter):
         # TODO: the attribute node_size should be removed in the future
         setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
         setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
+        setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
         n.meta['type'] = type(result)

         # retain the autograd graph
@@ -1115,7 +1115,8 @@ class PipelineEngineBase(ABC, nn.Module):
         # let each worker know global worker rref (include itself)
         sync_futs = []
         for pp_rank in self.pp_rank_to_worker_rref:
-            fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
+            fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs(
+                self.pp_rank_to_worker_rref)
             sync_futs.append(fut)

         for fut in sync_futs:
@@ -8,11 +8,16 @@ from torch import nn
 from tqdm import tqdm

 from colossalai.fx import ColoTracer
-from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass
+from colossalai.fx.passes.adding_split_node_pass import (
+    avgnode_split_pass,
+    gpipe_dp_split_pass,
+    split_with_split_nodes_pass,
+)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.pipeline.middleware.adaptor import get_fx_topology
-from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
 from colossalai.pipeline.rpc.utils import rpc_run

@@ -55,13 +60,25 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


-def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
+# Create an annotated model that marks where to split.
+def get_annotated_model(model, data_kwargs, num_stages, num_microbatches):
     tracer = ColoTracer()
     meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
-    annotated_model = avgnode_split_pass(gm, stage_num)
+
+    interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()])
+    interp = MetaInfoProp(gm)
+    interp.run(*interp_meta_args)
+
+    # annotated_model = avgnode_split_pass(gm, num_stages)
+    annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01)
+
+    return annotated_model
+
+
+def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, num_microbatches):
+    annotated_model = get_annotated_model(model, data_kwargs, num_stages, num_microbatches)
     top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
     topo = get_fx_topology(top_module)
     for submodule in split_submodules:
@@ -70,8 +87,8 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
     return split_submodules[pp_rank + 1]


-def partition(model, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
-    module = create_partition_module(pp_rank, stage_num, model, data_kwargs)
+def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int):
+    module = create_partition_module(pp_rank, stage_num, model, data_kwargs, num_microbatches)
     return module


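The reordered partition signature matters because run_master below binds it with functools.partial: the first three arguments (model, data_kwargs, num_microbatches) are fixed up front, and the pipeline engine is then expected to supply pp_rank, chunk, and stage_num per worker (inferred from the signature, not stated in the commit). A self-contained sketch with a stub in place of the real function; the values are made up:

from functools import partial


def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int):
    # Stub standing in for the real partition(); it only shows which arguments
    # arrive pre-bound and which are filled in later.
    return f'rank {pp_rank}/{stage_num}, chunk {chunk}, {num_microbatches} microbatches'


partition_fn = partial(partition, 'toy_model', {'input_ids': None}, 8)
print(partition_fn(pp_rank=0, chunk=0, stage_num=4))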
@@ -103,17 +120,19 @@ def run_master(args):
     warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}

     # create model
+    logger.info(f'start model_builder')
     model = model_builder(model_type)(checkpoint=False)
+    logger.info(f'end model_builder')

     # set 1f1b pipeline engine
-    pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs),
+    pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches),
                                        stage_num=stage_num,
                                        num_microbatches=num_microbatches,
                                        device=device,
                                        chunk=1,
                                        criterion=criterion,
                                        metric=None,
                                        checkpoint=False)

     partition_numels = pp_engine.remote_numels()
     for rank, numel in partition_numels.items():
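One practical note on the block_limit=0.01 used above: construct_blocks only emits a block once its accumulated fwd+bwd FLOPs reach limit * total FLOPs, so the number of DP items is bounded by roughly 1 / limit, whereas mode='node' feeds every fx node into the DP. A rough back-of-envelope sketch with illustrative numbers, not measurements:

# How block_limit bounds the DP problem size (rough estimate only).
block_limit = 0.01
max_blocks = int(1 / block_limit)    # at most ~100 blocks reach the DP

num_stages = 4
# do_dp_split_gpipe_impl iterates stages x start-node x mid-node, i.e. on the
# order of num_stages * n^2 / 2 inner updates per max_compute_cost candidate.
approx_inner_updates = num_stages * max_blocks ** 2 // 2
print(max_blocks, approx_inner_updates)    # 100 20000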