[Pipeline Middleware] Adapt scheduler for Topo (#2066)

* adapt scheduler for Topo

* remove comment

* fix set_input

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Ziyue Jiang 2022-12-05 20:23:41 +08:00 committed by GitHub
parent b3b89865e2
commit 597cdd3006
4 changed files with 160 additions and 128 deletions

File 1/4: pipeline middleware adaptor (get_topology)

@@ -108,7 +108,7 @@ def get_topology(gm: GraphModule):
         p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
         topo_input_partition.add_output_val(p_output_val)
     topo.set_partitions(partition_id=0, partition=topo_input_partition)
-    topo.set_input_partition(partition_id=0)
+    topo.set_input_partition_id(partition_id=0)

     for i, partition in enumerate(partitions):
         topo_mid_partition = Partition()
@@ -140,6 +140,6 @@ def get_topology(gm: GraphModule):
         torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
             find_input_in_partition(n, partitions, input_partitions)))
     topo.set_partitions(partition_id=1, partition=topo_output_partition)
-    topo.set_output_partition(partition_id=1)
+    topo.set_output_partition_id(partition_id=1)
     return topo
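
Note: the renamed setters record a partition id, not a Partition object; the Partition itself is registered separately through set_partitions. A minimal usage sketch — the getters it checks are added in the next file, and the assumption that both Topo constructor arguments default to None is taken from the __init__ excerpt shown there:

from colossalai.pipeline.middleware import Partition, Topo   # import path as used by the scheduler below

topo = Topo(input_partition_id=None, output_partition_id=None)
topo.set_partitions(partition_id=0, partition=Partition())    # model-input partition
topo.set_partitions(partition_id=1, partition=Partition())    # model-output partition
topo.set_input_partition_id(partition_id=0)     # record *which* id is the model input
topo.set_output_partition_id(partition_id=1)    # record *which* id is the model output

assert topo.get_input_partition_id() == 0
assert topo.get_input_partition() is topo.get_partition_by_id(0)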

File 2/4: pipeline middleware topo (Partition and Topo)

@@ -71,6 +71,36 @@ class Partition(object):
     def get_output_vals(self):
         return self._output_vals

+    # get the offsets of the outputs sent to dst_partition_id
+    def get_output_offsets(self, dst_partition_id):
+        res = []
+        for offset, output_val in enumerate(self._output_vals):
+            outputs = output_val.get()
+            for val_pos in outputs:
+                if val_pos.partition_id == dst_partition_id:
+                    res.append(offset)
+        return res
+
+    # get the src partition_ids of all inputs
+    def get_input_partition_ids(self):
+        res = []
+        for input_val in self._input_vals:
+            val_pos = input_val.get()
+            if val_pos.partition_id not in res:
+                res.append(val_pos.partition_id)
+        return res
+
+    # get the dst partition_ids of all outputs
+    def get_output_partition_ids(self):
+        res = []
+        for output_val in self._output_vals:
+            outputs = output_val.get()
+            for val_pos in outputs:
+                if val_pos.partition_id not in res:
+                    res.append(val_pos.partition_id)
+        return res
+
     def __str__(self) -> str:
         res = ''
@@ -107,11 +137,17 @@ class Topo(object):
         self._input_partition_id = input_partition_id
         self._output_partition_id = output_partition_id

-    def set_input_partition(self, partition_id: int):
+    def set_input_partition_id(self, partition_id: int):
         self._input_partition_id = partition_id

-    def set_output_partition(self, partition_id: int):
+    def set_output_partition_id(self, partition_id: int):
         self._output_partition_id = partition_id

+    def get_input_partition_id(self):
+        return self._input_partition_id
+
+    def get_output_partition_id(self):
+        return self._output_partition_id
+
     def set_partitions(self, partition_id: int, partition: Partition):
         self._partitions[partition_id] = partition
@@ -124,6 +160,9 @@
             res[partition_id] = partition
         return res

+    def get_mid_partition_ids(self):
+        return list(self.get_mid_partitions().keys())
+
     def get_input_partition(self):
         if self._input_partition_id is not None:
             return self._partitions[self._input_partition_id]
@@ -133,6 +172,9 @@
         if self._output_partition_id is not None:
             return self._partitions[self._output_partition_id]
         return None

+    def get_partition_by_id(self, partition_id):
+        return self._partitions[partition_id]
+
     def __str__(self) -> str:
         res = ''
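
Note: get_output_offsets answers "which of my output slots feed partition X?". A self-contained sketch of the same logic — the real PartitionOutputVal constructor is not shown in this diff, so ValPos/FakeOutputVal below are hypothetical stand-ins that only mimic the .get()/.partition_id interface used above:

from collections import namedtuple

ValPos = namedtuple('ValPos', ['partition_id', 'offset'])

class FakeOutputVal:
    def __init__(self, *positions):    # one output slot, possibly consumed by several partitions
        self._to = list(positions)
    def get(self):
        return self._to

class FakePartition:
    def __init__(self, output_vals):
        self._output_vals = output_vals
    def get_output_offsets(self, dst_partition_id):    # same body as the method added above
        res = []
        for offset, output_val in enumerate(self._output_vals):
            for val_pos in output_val.get():
                if val_pos.partition_id == dst_partition_id:
                    res.append(offset)
        return res

# outputs 0 and 2 go to partition 3, output 1 goes to partition 4
p = FakePartition([FakeOutputVal(ValPos(3, 0)),
                   FakeOutputVal(ValPos(4, 0)),
                   FakeOutputVal(ValPos(3, 1))])
assert p.get_output_offsets(3) == [0, 2]
assert p.get_output_offsets(4) == [1]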

File 3/4: pipeline RPC scheduler (WorkerBase)

@@ -11,6 +11,7 @@ import torch.distributed.rpc as rpc
 from colossalai.pipeline.pipeline_process_group import ppg
 from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
                                            split_batch, tensor_shape_list, type_detail)
+from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
 from torch import autograd, nn, optim
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
@@ -128,7 +129,6 @@ class WorkerBase(ABC):
         # topology info
         self.producer_stage_ids: List[int] = None
         self.consumer_stage_ids: List[int] = None
-        self.input_consumer_stage_ids: List[int] = None

         # module partitions
         self.partition_fn = partition_fn
@@ -137,9 +137,7 @@ class WorkerBase(ABC):
         self.metric = metric

         # middleware info
-        self._is_input = False
         self._is_output = False
-        self._producer_consumer_initialized = False

         # context to maintain loop
         self._initialize_context_container()
@@ -170,7 +168,6 @@ class WorkerBase(ABC):
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
         self.label_lock = threading.Condition(threading.Lock())
-        self.producer_consumer_init_lock = threading.Condition(threading.Lock())

     def _initialize_partition(self):
         partition_fn = self.partition_fn
@@ -207,6 +204,7 @@ class WorkerBase(ABC):
             self.output_list.pop(key)
             return output
+
     def get_parameters(self) -> List[torch.Tensor]:
         return [p for p in self.module_partition.parameters()]
@@ -251,7 +249,6 @@ class WorkerBase(ABC):
     # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
     # TODO(jiangziyue) Define a Class for DAG.
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
-        assert self.consumer_stage_ids is not None
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         output = self._get_future_by_device()
@@ -269,20 +266,11 @@ class WorkerBase(ABC):
            arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)

            # the first stage assigns the correct inputs to the other stages
-           DAG = self.get_DAG()
-           DAG_node = DAG['input_partition']
-           self_input_offsets = []
+           topo: Topo = self.get_topo()
+           self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+           input_partition = topo.get_input_partition()
+           self_input_offsets = input_partition.get_output_offsets(self_partition_id)
            recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
-           # notify rank which should receive extra input
-           offset = 0
-           for details in DAG_node.values():
-               for partition_name in details['output'].keys():
-                   recv_rank = self.partition_name_to_pp_rank(partition_name)
-                   if recv_rank == self.pp_rank:
-                       self_input_offsets.append(offset)
-                   elif recv_rank not in self.input_consumer_stage_ids:
-                       self.input_consumer_stage_ids.append(recv_rank)
-                   offset += 1

            # set input for self rank
            self_arg_lst = []
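
Note: with the Topo in place, stage 0 no longer walks a DAG dict to find its own offsets; it asks the input partition directly. A toy version of the selection step, with hypothetical values (cf. the get_output_offsets sketch after the Partition hunks above):

# the model input produces 3 tensors; this stage's partition id is 2
# and it consumes offsets [0, 2]
arg_lst = ['x0', 'x1', 'x2']    # flattened microbatch args held by stage 0
self_input_offsets = [0, 2]     # input_partition.get_output_offsets(self_partition_id)

self_arg_lst = [arg_lst[off] for off in self_input_offsets]
assert self_arg_lst == ['x0', 'x2']
# the full arg_lst is still published as a Phase.INPUT work item so that
# the other stages can fetch their own offsets remotely
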
@@ -295,7 +283,7 @@ class WorkerBase(ABC):
                self.work_list[key] = work_item
                self.work_list_condition_lock.notify_all()

-           # put input tensor which other nodes need into output_list
+           # put input tensor which other nodes need into output_list as Phase.INPUT
            work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
                                        self.num_microbatches, forward_only)
@@ -344,16 +332,10 @@ class WorkerBase(ABC):
            key = UniqueKey(microbatch_id, Phase.FORWARD)
            if key in self.work_list:
                return

-       producer_stage_ids = []
-       with self.producer_consumer_init_lock:
-           self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
-           producer_stage_ids = self.producer_stage_ids
+       producer_stage_ids = self.get_producer_stage_ids()
        producer_num = len(producer_stage_ids)
-       # TODO(jiangziyue) get single value instead of the whole output
        if self.need_model_input():
-           producer_num += 1    # extra one(the last one) for input_tensor
+           producer_num += 1    # for input partition
        subscribe_forward_futures: List[Future] = [None] * producer_num

        # TODO(jiangziyue) get single value instead of the whole output
@@ -374,7 +356,6 @@ class WorkerBase(ABC):
            producer_stage_id = producer_stage_ids[i]
            producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
            producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-           #producer_partition_name = self.pp_rank_to_partition_name[producer_stage_id]
            subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)

        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
@@ -415,6 +396,44 @@ class WorkerBase(ABC):
            assert key not in self.work_list
            self.work_list[key] = work_item_from_consumer
            self.work_list_condition_lock.notify_all()

+   def get_producer_stage_ids(self):
+       producer_stage_ids = []
+       rank = self.pp_rank
+       if not self.use_middleware():
+           prev_rank = rank - 1
+           if prev_rank >= 0:
+               producer_stage_ids.append(prev_rank)
+       else:
+           topo: Topo = self.get_topo()
+           self_partition_id = self.pp_rank_to_partition_id(rank, topo)
+           self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+           input_partition_ids = self_partition.get_input_partition_ids()
+           model_input_partition_id = topo.get_input_partition_id()
+           for partition_id in input_partition_ids:
+               # the model-input partition is ignored here; it is handled separately
+               if partition_id != model_input_partition_id:
+                   producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
+       return producer_stage_ids
+
+   def get_consumer_stage_ids(self):
+       consumer_stage_ids = []
+       rank = self.pp_rank
+       if not self.use_middleware():
+           next_rank = rank + 1
+           if next_rank <= self.actual_stage_num - 1:
+               consumer_stage_ids.append(next_rank)
+       else:
+           topo: Topo = self.get_topo()
+           self_partition_id = self.pp_rank_to_partition_id(rank, topo)
+           self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+           output_partition_ids = self_partition.get_output_partition_ids()
+           model_output_partition_id = topo.get_output_partition_id()
+           for partition_id in output_partition_ids:
+               if model_output_partition_id != partition_id:
+                   consumer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
+       return consumer_stage_ids
+
    def _get_producer_consumer(self) -> None:
        rank = self.pp_rank
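
Note: producers and consumers are now derived purely from partition ids. A runnable sketch with stand-in Topo/Partition classes; the ids are hypothetical, following the adaptor's convention that ids 0/1 are the model input/output partitions, so mid partitions 2, 3, 4 map to pp_ranks 0, 1, 2:

class ToyPartition:
    def __init__(self, input_ids):
        self._input_ids = input_ids
    def get_input_partition_ids(self):
        return self._input_ids

class ToyTopo:
    def __init__(self, mid_ids, partitions, input_id):
        self._mid_ids, self._partitions, self._input_id = mid_ids, partitions, input_id
    def get_mid_partition_ids(self):
        return self._mid_ids
    def get_partition_by_id(self, pid):
        return self._partitions[pid]
    def get_input_partition_id(self):
        return self._input_id

# stage 1 (partition 3) reads from stage 0 (partition 2) and the model input (id 0)
topo = ToyTopo(mid_ids=[2, 3, 4],
               partitions={3: ToyPartition([2, 0])},
               input_id=0)

rank = 1
self_id = topo.get_mid_partition_ids()[rank]                 # pp_rank_to_partition_id
producers = [topo.get_mid_partition_ids().index(pid)         # partition_id_to_pp_rank
             for pid in topo.get_partition_by_id(self_id).get_input_partition_ids()
             if pid != topo.get_input_partition_id()]        # model input handled separately
assert producers == [0]
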
@@ -422,64 +441,32 @@ class WorkerBase(ABC):
        assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"

        # should be arranged in order, following the order of the inputs of the current forward pass
-       self.producer_stage_ids = []
-       self.consumer_stage_ids = []
-
-       if not self.use_middleware():
-           # Just for demo
-           prev_rank = rank - 1
-           next_rank = rank + 1
-           if prev_rank >= 0:
-               self.producer_stage_ids.append(prev_rank)
-           if next_rank <= self.actual_stage_num - 1:
-               self.consumer_stage_ids.append(next_rank)
-       else:
-           self.input_consumer_stage_ids = []
-           DAG = self.get_DAG()
-           DAG_node_name = self.pp_rank_to_partition_name(rank)
-           DAG_node = DAG[DAG_node_name]
-           for partition_name in DAG_node['input'].keys():
-               if partition_name == 'MODEL_INPUT':
-                   self._is_input = True
-               else:
-                   prev_rank = self.partition_name_to_pp_rank(partition_name)
-                   self.producer_stage_ids.append(prev_rank)
-           for partition_name in DAG_node['output'].keys():
-               if partition_name == 'MODEL_OUTPUT':
-                   self._is_output = True
-               else:
-                   next_rank = self.partition_name_to_pp_rank(partition_name)
-                   self.consumer_stage_ids.append(next_rank)
-           # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
-           with self.producer_consumer_init_lock:
-               self._producer_consumer_initialized = True
-               self.producer_consumer_init_lock.notify_all()
+       self.producer_stage_ids = self.get_producer_stage_ids()
+       self.consumer_stage_ids = self.get_consumer_stage_ids()

    # TODO(jiangziyue) Define a Class for DAG.
-   def pp_rank_to_partition_name(self, pp_rank: int):
-       prefix = 'submod_'
-       partition_name = prefix + str(pp_rank)
-       return partition_name
+   def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo):
+       partition_ids = topo.get_mid_partition_ids()
+       return partition_ids[pp_rank]

    # TODO(jiangziyue) Define a Class for DAG.
-   def partition_name_to_pp_rank(self, partition_name: str) -> int:
-       prefix = 'submod_'
-       pp_rank = int(partition_name.split(prefix)[-1])
-       return pp_rank
+   def partition_id_to_pp_rank(self, partition_id: int, topo: Topo):
+       partition_ids = topo.get_mid_partition_ids()
+       for i, id in enumerate(partition_ids):
+           if id == partition_id:
+               return i

-   def get_DAG(self):
+   def get_topo(self):
        with self.partition_condition_lock:
            self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
-           if hasattr(self.module_partition, '_DAG'):
-               return self.module_partition._DAG
+           if hasattr(self.module_partition, '_topo'):
+               return self.module_partition._topo
            else:
                return None

    def use_middleware(self):
-       DAG = self.get_DAG()
-       return DAG is not None
+       topo = self.get_topo()
+       return topo is not None

    # TODO(jiangziyue) get single value instead of the whole output
    def _get_real_args_kwargs(self, args_or_kwargs):
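
Note: the two mapping helpers rely on get_mid_partition_ids() returning a stably ordered list (Python dicts preserve insertion order since 3.7), so list position doubles as pp_rank. The round-trip invariant, with plain functions standing in for the WorkerBase methods:

def pp_rank_to_partition_id(pp_rank, mid_ids):
    return mid_ids[pp_rank]

def partition_id_to_pp_rank(partition_id, mid_ids):
    for i, pid in enumerate(mid_ids):
        if pid == partition_id:
            return i

mid_ids = [2, 3, 4]    # e.g. ids 0/1 taken by the model input/output partitions
for rank in range(len(mid_ids)):
    assert partition_id_to_pp_rank(pp_rank_to_partition_id(rank, mid_ids), mid_ids) == rank
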
@@ -497,58 +484,48 @@ class WorkerBase(ABC):
        if args_or_kwargs is not None:
            if isinstance(args_or_kwargs, dict):
                pass
            else:
                flatten_args = []
                if self.is_first_stage():
                    pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
                    # TODO get by offset
                else:
-                   DAG = self.get_DAG()
-                   producer_outputs = {}
-                   cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
-                   #cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
-                   for i, args_from_one_mod in enumerate(args_or_kwargs):
-                       producer_output_offsets = []
-                       if self.need_model_input():
-                           if i == 0:
-                               producer_DAG_node = DAG['input_partition']
-                               producer_partition_name = 'MODEL_INPUT'
-                               offset = 0
-                               for arg_info in producer_DAG_node.values():
-                                   if cur_DAG_node_name in arg_info['output']:
-                                       producer_output_offsets.append(offset)
-                                   offset += 1
-                           else:
-                               producer_rank = self.producer_stage_ids[i-1]
-                               producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
-                               producer_DAG_node = DAG[producer_partition_name]
-                               producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
-                       else:
-                           producer_rank = self.producer_stage_ids[i]
-                           producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
-                           producer_DAG_node = DAG[producer_partition_name]
-                           producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
-                       if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
-                           producer_outputs[producer_partition_name] = [args_from_one_mod]
-                       else:
-                           producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]
-                   cur_DAG_node_input = DAG[cur_DAG_node_name]['input']
-                   def get_input_len(DAG_node_input):
-                       res = 0
-                       for offsets in DAG_node_input.values():
-                           res += len(offsets)
-                       return res
-                   input_len = get_input_len(cur_DAG_node_input)
-                   flatten_args = [None] * input_len
-                   for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
-                       for i, args_input_offset in enumerate(args_input_offsets):
-                           flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]
+                   topo: Topo = self.get_topo()
+                   self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+                   self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+                   model_input_partition_id = topo.get_input_partition_id()
+                   input_vals = self_partition.get_input_vals()
+                   producer_stage_ids = self.get_producer_stage_ids()
+                   if self.need_model_input():
+                       # 0 for data from the input batch
+                       # >= 1 for data from prev stages
+                       base = 1
+                   else:
+                       # data from prev stages only
+                       base = 0
+                   for val in input_vals:
+                       val_pos = val.get()
+                       src_partition_id = val_pos.partition_id
+                       src_offset = val_pos.offset
+                       src_index = base
+                       src_partition = topo.get_partition_by_id(src_partition_id)
+                       output_len = len(src_partition.get_output_vals())
+                       # data from a non-input partition
+                       if src_partition_id != model_input_partition_id:
+                           src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
+                           src_index = base
+                           for i, stage_id in enumerate(producer_stage_ids):
+                               if stage_id == src_stage_id:
+                                   src_index += i
+                                   break
+                       else:    # data from the input partition
+                           src_index = 0
+                       # when output_len == 1, the produced output is not iterable
+                       if output_len == 1:
+                           target = args_or_kwargs[src_index]
+                       else:
+                           target = args_or_kwargs[src_index][src_offset]
+                       flatten_args.append(target)
                args_or_kwargs = flatten_args
        return args_or_kwargs
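
Note: the rewritten gather resolves each input value by (source partition, offset) instead of DAG node names. A toy walk-through with hypothetical values; the output_len == 1 special case from the hunk is omitted for brevity, and plain namedtuples stand in for the PartitionInputVal positions:

from collections import namedtuple

ValPos = namedtuple('ValPos', ['partition_id', 'offset'])

args_or_kwargs = [('a0', 'a1'),    # slot 0: model input (the base slot when need_model_input)
                  ('b0', 'b1')]    # slot 1: output tuple of producer stage 0 (partition 2)
input_vals = [ValPos(0, 1),        # from the model-input partition (id 0), offset 1 -> 'a1'
              ValPos(2, 0)]        # from partition 2, offset 0 -> 'b0'

producer_stage_ids = [0]           # partition 2 lives on pp_rank 0
mid_ids = [2, 3, 4]
model_input_partition_id, base = 0, 1    # need_model_input() is True here

flatten_args = []
for pos in input_vals:
    if pos.partition_id == model_input_partition_id:
        src_index = 0              # the model input always sits at slot 0
    else:
        src_stage = mid_ids.index(pos.partition_id)
        src_index = base + producer_stage_ids.index(src_stage)
    flatten_args.append(args_or_kwargs[src_index][pos.offset])

assert flatten_args == ['a1', 'b0']
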
@@ -565,7 +542,15 @@ class WorkerBase(ABC):
        return self.pp_rank == self.actual_stage_num - 1

    def need_model_input(self):
-       return not self.is_first_stage() and self._is_input
+       need_input = False
+       topo: Topo = self.get_topo()
+       self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+       self_partition = topo.get_partition_by_id(self_partition_id)
+       partition_inputs = self_partition.get_input_partition_ids()
+       model_input_partition_id = topo.get_input_partition_id()
+       if model_input_partition_id in partition_inputs:
+           need_input = True
+       return not self.is_first_stage() and need_input

    def _default_data_process_func(self, args_kwargs):
        if self.is_first_stage():

File 4/4: pipeline RPC test utility (create_partition_module)

@@ -4,6 +4,7 @@ from torch import nn
 from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
 from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
 from colossalai.fx import ColoTracer
+from colossalai.pipeline.middleware.adaptor import get_fx_topology
 from rpc_test_utils import rpc_run, parse_args, MLP
 from functools import partial
@@ -18,8 +19,12 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
     annotated_model = balanced_split_pass(gm, stage_num)
-    split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
-    return list(split_model.children())[pp_rank]
+    top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
+    topo = get_fx_topology(top_module)
+    for submodule in split_submodules:
+        if isinstance(submodule, torch.fx.GraphModule):
+            setattr(submodule, '_topo', topo)
+    return split_submodules[pp_rank+1]

 def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
     torch.manual_seed(1024)
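
Note: the test attaches one shared Topo to every split submodule, which is exactly the `_topo` attribute WorkerBase.get_topo() reads back to enable the middleware path. A minimal illustration of the attachment pattern — nn.Linear stands in for the fx.GraphModule stages and a plain object() for the Topo built by get_fx_topology:

import torch.nn as nn

shared_topo = object()                          # stand-in for get_fx_topology(top_module)
stages = [nn.Linear(4, 4) for _ in range(3)]    # stand-ins for the split GraphModules
for stage in stages:
    setattr(stage, '_topo', shared_topo)        # same attribute name the worker checks

# every stage sees the identical topology object, so use_middleware() is True on all ranks
assert all(getattr(s, '_topo') is shared_topo for s in stages)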