mirror of https://github.com/hpcaitech/ColossalAI
[Pipeline Middleware] Adapt scheduler for Topo (#2066)
* adapt scheduler for Topo
* remove comment
* fix set input

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
parent b3b89865e2
commit 597cdd3006
@@ -108,7 +108,7 @@ def get_topology(gm: GraphModule):
         p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
         topo_input_partition.add_output_val(p_output_val)
     topo.set_partitions(partition_id=0, partition=topo_input_partition)
-    topo.set_input_partition(partition_id=0)
+    topo.set_input_partition_id(partition_id=0)
 
     for i, partition in enumerate(partitions):
         topo_mid_partition = Partition()
@@ -140,6 +140,6 @@ def get_topology(gm: GraphModule):
         torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
             find_input_in_partition(n, partitions, input_partitions)))
     topo.set_partitions(partition_id=1, partition=topo_output_partition)
-    topo.set_output_partition(partition_id=1)
+    topo.set_output_partition_id(partition_id=1)
 
     return topo
@@ -72,6 +72,36 @@ class Partition(object):
     def get_output_vals(self):
         return self._output_vals
 
+    # get the output offsets sent to dst_partition_id
+    def get_output_offsets(self, dst_partition_id):
+        res = []
+        for offset, output_val in enumerate(self._output_vals):
+            outputs = output_val.get()
+            for val_pos in outputs:
+                if val_pos.partition_id == dst_partition_id:
+                    res.append(offset)
+
+        return res
+
+    # get all input dst partition_ids
+    def get_input_partition_ids(self):
+        res = []
+        for input_val in self._input_vals:
+            val_pos = input_val.get()
+            if val_pos.partition_id not in res:
+                res.append(val_pos.partition_id)
+        return res
+
+    # get all output dst partition_ids
+    def get_output_partition_ids(self):
+        res = []
+        for output_val in self._output_vals:
+            outputs = output_val.get()
+            for val_pos in outputs:
+                if val_pos.partition_id not in res:
+                    res.append(val_pos.partition_id)
+        return res
+
     def __str__(self) -> str:
         res = ''
         res += f' input:\n'
@@ -107,12 +137,18 @@ class Topo(object):
         self._input_partition_id = input_partition_id
         self._output_partition_id = output_partition_id
 
-    def set_input_partition(self, partition_id: int):
+    def set_input_partition_id(self, partition_id: int):
         self._input_partition_id = partition_id
 
-    def set_output_partition(self, partition_id: int):
+    def set_output_partition_id(self, partition_id: int):
         self._output_partition_id = partition_id
 
+    def get_input_partition_id(self):
+        return self._input_partition_id
+
+    def get_output_partition_id(self):
+        return self._output_partition_id
+
     def set_partitions(self, partition_id: int, partition: Partition):
         self._partitions[partition_id] = partition
 
@@ -124,6 +160,9 @@ class Topo(object):
             res[partition_id] = partition
         return res
 
+    def get_mid_partition_ids(self):
+        return list(self.get_mid_partitions().keys())
+
     def get_input_partition(self):
         if self._input_partition_id is not None:
             return self._partitions[self._input_partition_id]
@@ -134,6 +173,9 @@ class Topo(object):
             return self._partitions[self._output_partition_id]
         return None
 
+    def get_partition_by_id(self, partition_id):
+        return self._partitions[partition_id]
+
     def __str__(self) -> str:
         res = ''
         if len(self._partitions) == 0:
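The lookup helpers added to Partition above only re-read bookkeeping the object already stores: each output value records the (partition_id, offset) positions it is sent to, and each input value records the position it comes from. A standalone sketch of that bookkeeping, using hypothetical mock classes rather than the real PartitionInputVal/PartitionOutputVal API:

from dataclasses import dataclass, field
from typing import List

# Hypothetical stand-in for the (partition_id, offset) records that the real
# PartitionInputVal/PartitionOutputVal objects expose through .get().
@dataclass
class ValPosition:
    partition_id: int
    offset: int

@dataclass
class MockPartition:
    # output_vals[i]: all positions that output offset i is sent to
    output_vals: List[List[ValPosition]] = field(default_factory=list)
    # input_vals[j]: the position that input offset j is read from
    input_vals: List[ValPosition] = field(default_factory=list)

    def get_output_offsets(self, dst_partition_id: int) -> List[int]:
        # same walk as Partition.get_output_offsets: which of my output
        # offsets does dst_partition_id consume?
        return [offset for offset, targets in enumerate(self.output_vals)
                if any(pos.partition_id == dst_partition_id for pos in targets)]

    def get_input_partition_ids(self) -> List[int]:
        res = []
        for pos in self.input_vals:
            if pos.partition_id not in res:
                res.append(pos.partition_id)
        return res

# partition 1 sends output 0 to partition 2 and output 1 to partition 3
p1 = MockPartition(output_vals=[[ValPosition(2, 0)], [ValPosition(3, 0)]])
assert p1.get_output_offsets(2) == [0]
assert p1.get_output_offsets(3) == [1]

# partition 2 reads its single input from partition 1, offset 0
p2 = MockPartition(input_vals=[ValPosition(1, 0)])
assert p2.get_input_partition_ids() == [1]

get_output_offsets(self_partition_id) on the input partition is exactly the query that set_input performs later in this commit to decide which slices of the model input stay on the local rank.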
@@ -11,6 +11,7 @@ import torch.distributed.rpc as rpc
 from colossalai.pipeline.pipeline_process_group import ppg
 from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
                                            split_batch, tensor_shape_list, type_detail)
+from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
 from torch import autograd, nn, optim
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
@@ -128,7 +129,6 @@ class WorkerBase(ABC):
         # topology info
         self.producer_stage_ids: List[int] = None
         self.consumer_stage_ids: List[int] = None
-        self.input_consumer_stage_ids: List[int] = None
 
         # module partitions
         self.partition_fn = partition_fn
@@ -137,9 +137,7 @@ class WorkerBase(ABC):
         self.metric = metric
 
         # middleware info
-        self._is_input = False
         self._is_output = False
-        self._producer_consumer_initialized = False
 
         # context to maintain loop
         self._initialize_context_container()
@@ -170,7 +168,6 @@ class WorkerBase(ABC):
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
         self.label_lock = threading.Condition(threading.Lock())
-        self.producer_consumer_init_lock = threading.Condition(threading.Lock())
 
     def _initialize_partition(self):
         partition_fn = self.partition_fn
@@ -207,6 +204,7 @@ class WorkerBase(ABC):
             self.output_list.pop(key)
             return output
 
+
     def get_parameters(self) -> List[torch.Tensor]:
         return [p for p in self.module_partition.parameters()]
 
@@ -251,7 +249,6 @@ class WorkerBase(ABC):
     # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
     # TODO(jiangziyue) Define a Class for DAG.
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
-        assert self.consumer_stage_ids is not None
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         output = self._get_future_by_device()
 
@@ -269,20 +266,11 @@ class WorkerBase(ABC):
             arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)
 
             # first stage assign correct input into other stages
-            DAG = self.get_DAG()
-            DAG_node = DAG['input_partition']
-            self_input_offsets = []
+            topo: Topo = self.get_topo()
+            self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+            input_partition = topo.get_input_partition()
+            self_input_offsets = input_partition.get_output_offsets(self_partition_id)
             recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
-            # notify rank which should receive extra input
-            offset = 0
-            for details in DAG_node.values():
-                for partition_name in details['output'].keys():
-                    recv_rank = self.partition_name_to_pp_rank(partition_name)
-                    if recv_rank == self.pp_rank:
-                        self_input_offsets.append(offset)
-                    elif recv_rank not in self.input_consumer_stage_ids:
-                        self.input_consumer_stage_ids.append(recv_rank)
-                    offset += 1
 
             # set input for self rank
             self_arg_lst = []
@@ -295,7 +283,7 @@ class WorkerBase(ABC):
                 self.work_list[key] = work_item
                 self.work_list_condition_lock.notify_all()
 
-            # put input tensor which other nodes need into output_list
+            # put input tensor which other nodes need into output_list as Phase.INPUT
             work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
                                         self.num_microbatches, forward_only)
 
@@ -344,16 +332,10 @@ class WorkerBase(ABC):
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         if key in self.work_list:
             return
-
-        producer_stage_ids = []
-        with self.producer_consumer_init_lock:
-            self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
-            producer_stage_ids = self.producer_stage_ids
+        producer_stage_ids = self.get_producer_stage_ids()
         producer_num = len(producer_stage_ids)
-
-        # TODO(jiangziyue) get single value instead of the whole output
         if self.need_model_input():
-            producer_num += 1    # extra one(the last one) for input_tensor
+            producer_num += 1    # for input partition
         subscribe_forward_futures: List[Future] = [None] * producer_num
 
         # TODO(jiangziyue) get single value instead of the whole output
@@ -374,7 +356,6 @@ class WorkerBase(ABC):
                 producer_stage_id = producer_stage_ids[i]
                 producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                 producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                #producer_partition_name = self.pp_rank_to_partition_name[producer_stage_id]
                 subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
 
             work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
@@ -416,70 +397,76 @@ class WorkerBase(ABC):
             self.work_list[key] = work_item_from_consumer
             self.work_list_condition_lock.notify_all()
 
+    def get_producer_stage_ids(self):
+        producer_stage_ids = []
+        rank = self.pp_rank
+        if not self.use_middleware():
+            prev_rank = rank - 1
+            if prev_rank >= 0:
+                producer_stage_ids.append(prev_rank)
+        else:
+            topo: Topo = self.get_topo()
+            self_partition_id = self.pp_rank_to_partition_id(rank, topo)
+            self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+            input_partition_ids = self_partition.get_input_partition_ids()
+            model_input_partition_id = topo.get_input_partition_id()
+            for partition_id in input_partition_ids:
+                # ignore input partition in current implementation.
+                # it will be specially tackled.
+                if partition_id != model_input_partition_id:
+                    producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
+        return producer_stage_ids
+
+    def get_consumer_stage_ids(self):
+        consumer_stage_ids = []
+        rank = self.pp_rank
+        if not self.use_middleware():
+            next_rank = rank + 1
+            if next_rank <= self.actual_stage_num - 1:
+                consumer_stage_ids.append(next_rank)
+        else:
+            topo: Topo = self.get_topo()
+            self_partition_id = self.pp_rank_to_partition_id(rank, topo)
+            self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+            output_partition_ids = self_partition.get_output_partition_ids()
+            model_output_partition_id = topo.get_output_partition_id()
+            for partition_id in output_partition_ids:
+                if model_output_partition_id != partition_id:
+                    consumer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
+        return consumer_stage_ids
+
     def _get_producer_consumer(self) -> None:
         rank = self.pp_rank
         assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
         assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
 
         # should be aranged in order, the order of the input of current forward
-        self.producer_stage_ids = []
-        self.consumer_stage_ids = []
-
-        if not self.use_middleware():
-            # Just for demo
-            prev_rank = rank - 1
-            next_rank = rank + 1
-            if prev_rank >= 0:
-                self.producer_stage_ids.append(prev_rank)
-            if next_rank <= self.actual_stage_num - 1:
-                self.consumer_stage_ids.append(next_rank)
-        else:
-            self.input_consumer_stage_ids = []
-            DAG = self.get_DAG()
-            DAG_node_name = self.pp_rank_to_partition_name(rank)
-            DAG_node = DAG[DAG_node_name]
-            for partition_name in DAG_node['input'].keys():
-                if partition_name == 'MODEL_INPUT':
-                    self._is_input = True
-                else:
-                    prev_rank = self.partition_name_to_pp_rank(partition_name)
-                    self.producer_stage_ids.append(prev_rank)
-
-            for partition_name in DAG_node['output'].keys():
-                if partition_name == 'MODEL_OUTPUT':
-                    self._is_output = True
-                else:
-                    next_rank = self.partition_name_to_pp_rank(partition_name)
-                    self.consumer_stage_ids.append(next_rank)
-
-            # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
-            with self.producer_consumer_init_lock:
-                self._producer_consumer_initialized = True
-                self.producer_consumer_init_lock.notify_all()
+        self.producer_stage_ids = self.get_producer_stage_ids()
+        self.consumer_stage_ids = self.get_consumer_stage_ids()
 
     # TODO(jiangziyue) Define a Class for DAG.
-    def pp_rank_to_partition_name(self, pp_rank: int):
-        prefix = 'submod_'
-        partition_name = prefix + str(pp_rank)
-        return partition_name
+    def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo):
+        partition_ids = topo.get_mid_partition_ids()
+        return partition_ids[pp_rank]
 
     # TODO(jiangziyue) Define a Class for DAG.
-    def partition_name_to_pp_rank(self, partition_name: str) -> int:
-        prefix = 'submod_'
-        pp_rank = int(partition_name.split(prefix)[-1])
-        return pp_rank
+    def partition_id_to_pp_rank(self, partition_id: int, topo: Topo):
+        partition_ids = topo.get_mid_partition_ids()
+        for i, id in enumerate(partition_ids):
+            if id == partition_id:
+                return i
 
-    def get_DAG(self):
+    def get_topo(self):
         with self.partition_condition_lock:
             self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
-            if hasattr(self.module_partition, '_DAG'):
-                return self.module_partition._DAG
+            if hasattr(self.module_partition, '_topo'):
+                return self.module_partition._topo
             else:
                 return None
 
     def use_middleware(self):
-        DAG = self.get_DAG()
-        return DAG is not None
+        topo = self.get_topo()
+        return topo is not None
 
     # TODO(jiangziyue) get single value instead of the whole output
     def _get_real_args_kwargs(self, args_or_kwargs):
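get_producer_stage_ids and get_consumer_stage_ids above reduce to a rank-to-partition lookup followed by a neighbour query on the topology, with the special model-input/model-output partitions filtered out and handled separately. A hedged, self-contained sketch of that mapping, with plain dicts standing in for the real Topo object:

from typing import Dict, List

# Hypothetical topology: mid partition ids in pipeline order plus an
# adjacency map of who reads from whom; not the ColossalAI Topo class.
mid_partition_ids = [2, 3, 4]          # pp_rank i <-> mid_partition_ids[i]
model_input_partition_id = 0
inputs_of: Dict[int, List[int]] = {
    2: [0],        # stage 0 reads only the model input
    3: [0, 2],     # stage 1 reads the model input and stage 0
    4: [2, 3],     # stage 2 reads stages 0 and 1 (a skip connection)
}

def pp_rank_to_partition_id(rank: int) -> int:
    return mid_partition_ids[rank]

def partition_id_to_pp_rank(partition_id: int) -> int:
    return mid_partition_ids.index(partition_id)

def get_producer_stage_ids(rank: int) -> List[int]:
    self_pid = pp_rank_to_partition_id(rank)
    # the model-input partition is skipped here and "specially tackled",
    # matching the comment in the hunk above
    return [partition_id_to_pp_rank(pid)
            for pid in inputs_of[self_pid]
            if pid != model_input_partition_id]

def need_model_input(rank: int) -> bool:
    self_pid = pp_rank_to_partition_id(rank)
    return rank != 0 and model_input_partition_id in inputs_of[self_pid]

assert get_producer_stage_ids(2) == [0, 1]   # rank 2 subscribes to ranks 0 and 1
assert get_producer_stage_ids(0) == []       # rank 0 takes the batch directly
assert need_model_input(1) is True           # rank 1 needs the extra input path
assert need_model_input(0) is False          # the first stage never does

The need_model_input check here mirrors the rewritten WorkerBase.need_model_input further down: a non-first stage whose partition lists the model-input partition among its inputs has to subscribe to one extra future.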
@@ -503,52 +490,42 @@ class WorkerBase(ABC):
                     pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
                 # TODO get by offset
                 else:
-                    DAG = self.get_DAG()
-                    producer_outputs = {}
-                    cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
-                    #cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
-                    for i, args_from_one_mod in enumerate(args_or_kwargs):
-                        producer_output_offsets = []
+                    topo: Topo = self.get_topo()
+                    self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+                    self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+                    model_input_partition_id = topo.get_input_partition_id()
+                    input_vals = self_partition.get_input_vals()
+                    producer_stage_ids = self.get_producer_stage_ids()
                     if self.need_model_input():
-                        if i == 0:
-                            producer_DAG_node = DAG['input_partition']
-                            producer_partition_name = 'MODEL_INPUT'
-                            offset = 0
-                            for arg_info in producer_DAG_node.values():
-                                if cur_DAG_node_name in arg_info['output']:
-                                    producer_output_offsets.append(offset)
-                                offset += 1
+                        # 0 for data from input batch
+                        # >= 1 for data from prev stages
+                        base = 1
                     else:
-                        producer_rank = self.producer_stage_ids[i-1]
-                        producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
-                        producer_DAG_node = DAG[producer_partition_name]
-                        producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
-
+                        # data from prev stages
+                        base = 0
+                    for val in input_vals:
+                        val_pos = val.get()
+                        src_partition_id = val_pos.partition_id
+                        src_offset = val_pos.offset
+                        src_index = base
+                        src_partition = topo.get_partition_by_id(src_partition_id)
+                        output_len = len(src_partition.get_output_vals())
+                        # data from not-input partition
+                        if src_partition_id != model_input_partition_id:
+                            src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
+                            src_index = base
+                            for i, stage_id in enumerate(producer_stage_ids):
+                                if stage_id == src_stage_id:
+                                    src_index += i
+                                    break
+                        else:    # data from input partition
+                            src_index = 0
+                        # when output_len = 1, not iterable
+                        if output_len == 1:
+                            target = args_or_kwargs[src_index]
                         else:
-                            producer_rank = self.producer_stage_ids[i]
-                            producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
-                            producer_DAG_node = DAG[producer_partition_name]
-                            producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
-
-                        if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
-                            producer_outputs[producer_partition_name] = [args_from_one_mod]
-                        else:
-                            producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]
-
-                    cur_DAG_node_input = DAG[cur_DAG_node_name]['input']
-
-                    def get_input_len(DAG_node_input):
-                        res = 0
-                        for offsets in DAG_node_input.values():
-                            res += len(offsets)
-                        return res
-
-                    input_len = get_input_len(cur_DAG_node_input)
-                    flatten_args = [None] * input_len
-                    for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
-                        for i, args_input_offset in enumerate(args_input_offsets):
-                            flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]
 
+                            target = args_or_kwargs[src_index][src_offset]
                         flatten_args.append(target)
                 args_or_kwargs = flatten_args
         return args_or_kwargs
@@ -565,7 +542,15 @@ class WorkerBase(ABC):
         return self.pp_rank == self.actual_stage_num - 1
 
     def need_model_input(self):
-        return not self.is_first_stage() and self._is_input
+        need_input = False
+        topo: Topo = self.get_topo()
+        self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+        self_partition = topo.get_partition_by_id(self_partition_id)
+        partition_inputs = self_partition.get_input_partition_ids()
+        model_input_partition_id = topo.get_input_partition_id()
+        if model_input_partition_id in partition_inputs:
+            need_input = True
+        return not self.is_first_stage() and need_input
 
     def _default_data_process_func(self, args_kwargs):
         if self.is_first_stage():
@@ -4,6 +4,7 @@ from torch import nn
 from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
 from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
 from colossalai.fx import ColoTracer
+from colossalai.pipeline.middleware.adaptor import get_fx_topology
 from rpc_test_utils import rpc_run, parse_args, MLP
 from functools import partial
 
@@ -18,8 +19,12 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
     annotated_model = balanced_split_pass(gm, stage_num)
-    split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
-    return list(split_model.children())[pp_rank]
+    top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
+    topo = get_fx_topology(top_module)
+    for submodule in split_submodules:
+        if isinstance(submodule, torch.fx.GraphModule):
+            setattr(submodule, '_topo', topo)
+    return split_submodules[pp_rank+1]
 
 def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
     torch.manual_seed(1024)
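The test change above is the glue between the FX pass and the runtime: the topology returned by get_fx_topology(top_module) is attached to every split GraphModule as a _topo attribute, which is exactly what WorkerBase.get_topo() and use_middleware() probe for. A minimal sketch of that handshake with a stand-in module and a placeholder topology object:

import torch
from torch import nn

class TinyStage(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

# stand-in for the object returned by get_fx_topology(top_module)
fake_topo = {"note": "placeholder topology object"}

stage = TinyStage()
# same attachment the test performs on each split GraphModule
setattr(stage, '_topo', fake_topo)

def get_topo(module: nn.Module):
    # mirrors WorkerBase.get_topo(): read the attribute if it exists
    return getattr(module, '_topo', None)

def use_middleware(module: nn.Module) -> bool:
    return get_topo(module) is not None

assert use_middleware(stage)
assert get_topo(TinyStage()) is None   # plain partitions fall back to the old path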