mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] Add Simplified Alpa DP Partition (#2507)
* add alpa dp split * add alpa dp split * use fwd+bwd instead of fwd only --------- Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>pull/2987/head
parent
b42d3d28ed
commit
400f63012e
|
@ -1,4 +1,6 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.node import Node
|
||||
|
||||
|
@ -9,6 +11,165 @@ def pipe_split():
|
|||
pass
|
||||
|
||||
|
||||
def block_split():
|
||||
pass
|
||||
|
||||
|
||||
# Construct blocks with the condition that (block_flops / total_flops) >= limit.
|
||||
def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
|
||||
total_fwd_flop = 0
|
||||
total_bwd_flop = 0
|
||||
for node in gm.graph.nodes:
|
||||
total_fwd_flop += node.fwd_flop
|
||||
total_bwd_flop += node.bwd_flop
|
||||
|
||||
total_flop = total_fwd_flop + total_bwd_flop
|
||||
per_block_flop = total_flop * limit
|
||||
accumulate_fwd_flop = 0
|
||||
accumulate_bwd_flop = 0
|
||||
block_nodes = []
|
||||
for node in gm.graph.nodes:
|
||||
if 'block_split' in node.name:
|
||||
continue
|
||||
accumulate_fwd_flop += node.fwd_flop
|
||||
accumulate_bwd_flop += node.bwd_flop
|
||||
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
|
||||
with gm.graph.inserting_after(node):
|
||||
block_node = gm.graph.create_node('call_function', block_split)
|
||||
setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
|
||||
setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
|
||||
accumulate_fwd_flop = 0
|
||||
accumulate_bwd_flop = 0
|
||||
block_nodes.append(block_node)
|
||||
|
||||
return block_nodes
|
||||
|
||||
|
||||
def remove_blocks(gm: torch.fx.GraphModule):
|
||||
for node in gm.graph.nodes:
|
||||
if (node.op, node.target) == ('call_function', block_split):
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
|
||||
def get_compute_costs(node_list):
|
||||
num_nodes = len(node_list)
|
||||
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
|
||||
|
||||
for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
|
||||
for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
|
||||
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
|
||||
all_compute_cost[start, end] = sum(selected_flops)
|
||||
|
||||
return all_compute_cost
|
||||
|
||||
|
||||
def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_costs, max_compute_cost):
|
||||
"""The core implementation of the DP algorithm."""
|
||||
# Adapted from Alpa DP Formulation.
|
||||
# For f, node ID start from 0
|
||||
# f[number of stages,
|
||||
# node id that is currently being considered]
|
||||
|
||||
# record time cost(assess by fwd+bwd flop now)
|
||||
f = np.full((num_stages + 1, num_nodes + 1), np.inf, dtype=np.float32)
|
||||
|
||||
# record max stage compute cost among all stages in this partition.
|
||||
f_stage_max = np.full((num_stages + 1, num_nodes + 1), 0.0, dtype=np.float32)
|
||||
# record start node index for next stage in this partition
|
||||
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
|
||||
f[0, num_nodes] = 0
|
||||
for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
|
||||
for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
|
||||
for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
|
||||
stage_cost = compute_costs[i, k - 1]
|
||||
new_cost = f[s - 1, k] + stage_cost
|
||||
if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
|
||||
f[s, i] = new_cost
|
||||
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
|
||||
f_argmin[s, i] = k
|
||||
|
||||
best_total_cost = f[num_stages, 0]
|
||||
if np.isinf(best_total_cost):
|
||||
return np.inf, None
|
||||
|
||||
total_cost = f[num_stages, 0] + (num_microbatches - 1) * f_stage_max[num_stages, 0]
|
||||
|
||||
current_s = num_stages
|
||||
current_node = 0
|
||||
|
||||
res = []
|
||||
while current_s > 0 and current_node < num_nodes:
|
||||
next_start_node = f_argmin[current_s, current_node]
|
||||
res.append((current_node, next_start_node))
|
||||
current_s -= 1
|
||||
current_node = next_start_node
|
||||
|
||||
return total_cost, res
|
||||
|
||||
|
||||
def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatches: int):
|
||||
# Ignore the memory cost profiling in Alpa's design for convenience.
|
||||
max_compute_costs = np.sort(np.unique(compute_costs))
|
||||
best_cost = np.inf
|
||||
best_solution = None
|
||||
last_max_compute_cost = 0.0
|
||||
gap = 1e6 # temporary magic number, unit: flops
|
||||
|
||||
for max_compute_cost in tqdm.tqdm(max_compute_costs):
|
||||
# Pruning to reduce search space.
|
||||
if max_compute_cost * num_microbatches >= best_cost:
|
||||
break
|
||||
if max_compute_cost - last_max_compute_cost < gap:
|
||||
continue
|
||||
|
||||
cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
|
||||
max_compute_cost)
|
||||
|
||||
if cost < best_cost:
|
||||
best_cost = cost
|
||||
best_solution = solution
|
||||
last_max_compute_cost = max_compute_cost
|
||||
return best_cost, best_solution
|
||||
|
||||
|
||||
# Auto DP partition based on Alpa.
|
||||
# Adapted to Gpipe Scheduler
|
||||
# split_mode:
|
||||
# 'node': fx_node
|
||||
# 'block': many fx_nodes construct a block
|
||||
def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
|
||||
assert mode in ['node', 'block']
|
||||
|
||||
# nodes or blocks will be used in partition.
|
||||
node_list = []
|
||||
if mode == 'node':
|
||||
for node in gm.graph.nodes:
|
||||
node_list.append(node)
|
||||
elif mode == 'block':
|
||||
node_list = construct_blocks(gm, limit=block_limit)
|
||||
else:
|
||||
pass
|
||||
|
||||
compute_costs = get_compute_costs(node_list)
|
||||
|
||||
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
|
||||
|
||||
for (_, next_start_node) in best_solution:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
node = node_list[next_start_node]
|
||||
with gm.graph.inserting_before(node):
|
||||
split_node = gm.graph.create_node('call_function', pipe_split)
|
||||
pp_size -= 1
|
||||
|
||||
# remove block node if possible
|
||||
if mode == 'block':
|
||||
remove_blocks(gm)
|
||||
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
||||
def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
"""
|
||||
In avgcompute_split_pass, we split module by the fwd flops.
|
||||
|
|
|
@ -114,6 +114,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
|
||||
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
|
||||
setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
|
||||
n.meta['type'] = type(result)
|
||||
|
||||
# retain the autograd graph
|
||||
|
|
|
@ -1115,7 +1115,8 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
# let each worker know global worker rref (include itself)
|
||||
sync_futs = []
|
||||
for pp_rank in self.pp_rank_to_worker_rref:
|
||||
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
|
||||
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs(
|
||||
self.pp_rank_to_worker_rref)
|
||||
sync_futs.append(fut)
|
||||
|
||||
for fut in sync_futs:
|
||||
|
|
|
@ -8,11 +8,16 @@ from torch import nn
|
|||
from tqdm import tqdm
|
||||
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass
|
||||
from colossalai.fx.passes.adding_split_node_pass import (
|
||||
avgnode_split_pass,
|
||||
gpipe_dp_split_pass,
|
||||
split_with_split_nodes_pass,
|
||||
)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.pipeline.middleware.adaptor import get_fx_topology
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from colossalai.pipeline.rpc.utils import rpc_run
|
||||
|
||||
|
||||
|
@ -55,13 +60,25 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
|
|||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
||||
|
||||
|
||||
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
|
||||
# Create annotated model which is noted where to be splitted.
|
||||
def get_annotated_model(model, data_kwargs, num_stages, num_microbatches):
|
||||
tracer = ColoTracer()
|
||||
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||
annotated_model = avgnode_split_pass(gm, stage_num)
|
||||
|
||||
interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()])
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.run(*interp_meta_args)
|
||||
|
||||
#annotated_model = avgnode_split_pass(gm, num_stages)
|
||||
annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01)
|
||||
|
||||
return annotated_model
|
||||
|
||||
|
||||
def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, num_microbatches):
|
||||
annotated_model = get_annotated_model(model, data_kwargs, num_stages, num_microbatches)
|
||||
top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
|
||||
topo = get_fx_topology(top_module)
|
||||
for submodule in split_submodules:
|
||||
|
@ -70,8 +87,8 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
|
|||
return split_submodules[pp_rank + 1]
|
||||
|
||||
|
||||
def partition(model, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
|
||||
module = create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
||||
def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int):
|
||||
module = create_partition_module(pp_rank, stage_num, model, data_kwargs, num_microbatches)
|
||||
return module
|
||||
|
||||
|
||||
|
@ -103,17 +120,19 @@ def run_master(args):
|
|||
warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}
|
||||
|
||||
# create model
|
||||
logger.info(f'start model_builder')
|
||||
model = model_builder(model_type)(checkpoint=False)
|
||||
logger.info(f'end model_builder')
|
||||
|
||||
# set 1f1b pipeline engine
|
||||
pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=1,
|
||||
criterion=criterion,
|
||||
metric=None,
|
||||
checkpoint=False)
|
||||
pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=1,
|
||||
criterion=criterion,
|
||||
metric=None,
|
||||
checkpoint=False)
|
||||
|
||||
partition_numels = pp_engine.remote_numels()
|
||||
for rank, numel in partition_numels.items():
|
||||
|
|
Loading…
Reference in New Issue