diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/pipeline/middleware/adaptor/fx.py
index 4351a6b49..8437c5194 100644
--- a/colossalai/pipeline/middleware/adaptor/fx.py
+++ b/colossalai/pipeline/middleware/adaptor/fx.py
@@ -108,7 +108,7 @@ def get_topology(gm: GraphModule):
             p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
             topo_input_partition.add_output_val(p_output_val)
     topo.set_partitions(partition_id=0, partition=topo_input_partition)
-    topo.set_input_partition(partition_id=0)
+    topo.set_input_partition_id(partition_id=0)
 
     for i, partition in enumerate(partitions):
         topo_mid_partition = Partition()
@@ -140,6 +140,6 @@ def get_topology(gm: GraphModule):
     torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
         find_input_in_partition(n, partitions, input_partitions)))
     topo.set_partitions(partition_id=1, partition=topo_output_partition)
-    topo.set_output_partition(partition_id=1)
+    topo.set_output_partition_id(partition_id=1)
 
     return topo
\ No newline at end of file
diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/pipeline/middleware/topo.py
index e9d97b0b7..e798e2ed9 100644
--- a/colossalai/pipeline/middleware/topo.py
+++ b/colossalai/pipeline/middleware/topo.py
@@ -71,6 +71,36 @@ class Partition(object):
 
     def get_output_vals(self):
         return self._output_vals
+
+    # get the output offsets sent to dst_partition_id
+    def get_output_offsets(self, dst_partition_id):
+        res = []
+        for offset, output_val in enumerate(self._output_vals):
+            outputs = output_val.get()
+            for val_pos in outputs:
+                if val_pos.partition_id == dst_partition_id:
+                    res.append(offset)
+
+        return res
+
+    # get all input dst partition_ids
+    def get_input_partition_ids(self):
+        res = []
+        for input_val in self._input_vals:
+            val_pos = input_val.get()
+            if val_pos.partition_id not in res:
+                res.append(val_pos.partition_id)
+        return res
+
+    # get all output dst partition_ids
+    def get_output_partition_ids(self):
+        res = []
+        for output_val in self._output_vals:
+            outputs = output_val.get()
+            for val_pos in outputs:
+                if val_pos.partition_id not in res:
+                    res.append(val_pos.partition_id)
+        return res
 
     def __str__(self) -> str:
         res = ''
@@ -107,11 +137,17 @@ class Topo(object):
         self._input_partition_id = input_partition_id
         self._output_partition_id = output_partition_id
 
-    def set_input_partition(self, partition_id: int):
+    def set_input_partition_id(self, partition_id: int):
         self._input_partition_id = partition_id
 
-    def set_output_partition(self, partition_id: int):
+    def set_output_partition_id(self, partition_id: int):
         self._output_partition_id = partition_id
+
+    def get_input_partition_id(self):
+        return self._input_partition_id
+
+    def get_output_partition_id(self):
+        return self._output_partition_id
 
     def set_partitions(self, partition_id: int, partition: Partition):
         self._partitions[partition_id] = partition
@@ -124,6 +160,9 @@
             res[partition_id] = partition
         return res
+
+    def get_mid_partition_ids(self):
+        return list(self.get_mid_partitions().keys())
 
     def get_input_partition(self):
         if self._input_partition_id is not None:
             return self._partitions[self._input_partition_id]
@@ -133,6 +172,9 @@
         if self._output_partition_id is not None:
             return self._partitions[self._output_partition_id]
         return None
+
+    def get_partition_by_id(self, partition_id):
+        return self._partitions[partition_id]
 
     def __str__(self) -> str:
         res = ''
diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py
index 6a6c2379b..e28a31624 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -11,6 +11,7 @@
 import torch.distributed.rpc as rpc
 from colossalai.pipeline.pipeline_process_group import ppg
 from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map, split_batch, tensor_shape_list, type_detail)
+from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
 from torch import autograd, nn, optim
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
@@ -128,7 +129,6 @@ class WorkerBase(ABC):
         # topology info
         self.producer_stage_ids: List[int] = None
         self.consumer_stage_ids: List[int] = None
-        self.input_consumer_stage_ids: List[int] = None
 
         # module partitions
         self.partition_fn = partition_fn
@@ -137,9 +137,7 @@
         self.metric = metric
 
         # middleware info
-        self._is_input = False
         self._is_output = False
-        self._producer_consumer_initialized = False
 
         # context to maintain loop
         self._initialize_context_container()
@@ -170,7 +168,6 @@
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
         self.label_lock = threading.Condition(threading.Lock())
-        self.producer_consumer_init_lock = threading.Condition(threading.Lock())
 
     def _initialize_partition(self):
         partition_fn = self.partition_fn
@@ -207,6 +204,7 @@
             self.output_list.pop(key)
         return output
 
+
     def get_parameters(self) -> List[torch.Tensor]:
         return [p for p in self.module_partition.parameters()]
 
@@ -251,7 +249,6 @@
     # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
     # TODO(jiangziyue) Define a Class for DAG.
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
-        assert self.consumer_stage_ids is not None
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         output = self._get_future_by_device()
 
@@ -269,20 +266,11 @@
             arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)
 
             # first stage assign correct input into other stages
-            DAG = self.get_DAG()
-            DAG_node = DAG['input_partition']
-            self_input_offsets = []
+            topo: Topo = self.get_topo()
+            self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+            input_partition = topo.get_input_partition()
+            self_input_offsets = input_partition.get_output_offsets(self_partition_id)
             recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
-            # notify rank which should receive extra input
-            offset = 0
-            for details in DAG_node.values():
-                for partition_name in details['output'].keys():
-                    recv_rank = self.partition_name_to_pp_rank(partition_name)
-                    if recv_rank == self.pp_rank:
-                        self_input_offsets.append(offset)
-                    elif recv_rank not in self.input_consumer_stage_ids:
-                        self.input_consumer_stage_ids.append(recv_rank)
-                offset += 1
 
             # set input for self rank
             self_arg_lst = []
@@ -295,7 +283,7 @@
                 self.work_list[key] = work_item
                 self.work_list_condition_lock.notify_all()
 
-            # put input tensor which other nodes need into output_list
+            # put input tensor which other nodes need into output_list as Phase.INPUT
             work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
                                         self.num_microbatches, forward_only)
 
@@ -344,16 +332,10 @@
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         if key in self.work_list:
             return
-
-        producer_stage_ids = []
-        with self.producer_consumer_init_lock:
-            self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
-            producer_stage_ids = self.producer_stage_ids
+        producer_stage_ids = self.get_producer_stage_ids()
         producer_num = len(producer_stage_ids)
-
-        # TODO(jiangziyue) get single value instead of the whole output
         if self.need_model_input():
-            producer_num += 1    # extra one(the last one) for input_tensor
+            producer_num += 1    # for input partition
         subscribe_forward_futures: List[Future] = [None] * producer_num
 
         # TODO(jiangziyue) get single value instead of the whole output
@@ -374,7 +356,6 @@
                 producer_stage_id = producer_stage_ids[i]
                 producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                 producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                #producer_partition_name = self.pp_rank_to_partition_name[producer_stage_id]
                 subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
 
         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
@@ -415,6 +396,44 @@
             assert key not in self.work_list
             self.work_list[key] = work_item_from_consumer
             self.work_list_condition_lock.notify_all()
+
+    def get_producer_stage_ids(self):
+        producer_stage_ids = []
+        rank = self.pp_rank
+        if not self.use_middleware():
+            prev_rank = rank - 1
+            if prev_rank >= 0:
+                producer_stage_ids.append(prev_rank)
+        else:
+            topo: Topo = self.get_topo()
+            self_partition_id = self.pp_rank_to_partition_id(rank, topo)
+            self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+            input_partition_ids = self_partition.get_input_partition_ids()
+            model_input_partition_id = topo.get_input_partition_id()
+            for partition_id in input_partition_ids:
+                # ignore input partition in current implementation.
+                # it will be specially tackled.
+                if partition_id != model_input_partition_id:
+                    producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
+        return producer_stage_ids
+
+    def get_consumer_stage_ids(self):
+        consumer_stage_ids = []
+        rank = self.pp_rank
+        if not self.use_middleware():
+            next_rank = rank + 1
+            if next_rank <= self.actual_stage_num - 1:
+                consumer_stage_ids.append(next_rank)
+        else:
+            topo: Topo = self.get_topo()
+            self_partition_id = self.pp_rank_to_partition_id(rank, topo)
+            self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+            output_partition_ids = self_partition.get_output_partition_ids()
+            model_output_partition_id = topo.get_output_partition_id()
+            for partition_id in output_partition_ids:
+                if model_output_partition_id != partition_id:
+                    consumer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
+        return consumer_stage_ids
 
     def _get_producer_consumer(self) -> None:
         rank = self.pp_rank
@@ -422,64 +441,32 @@
         assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
 
         # should be aranged in order, the order of the input of current forward
-        self.producer_stage_ids = []
-        self.consumer_stage_ids = []
-
-        if not self.use_middleware():
-            # Just for demo
-            prev_rank = rank - 1
-            next_rank = rank + 1
-            if prev_rank >= 0:
-                self.producer_stage_ids.append(prev_rank)
-            if next_rank <= self.actual_stage_num - 1:
-                self.consumer_stage_ids.append(next_rank)
-        else:
-            self.input_consumer_stage_ids = []
-            DAG = self.get_DAG()
-            DAG_node_name = self.pp_rank_to_partition_name(rank)
-            DAG_node = DAG[DAG_node_name]
-            for partition_name in DAG_node['input'].keys():
-                if partition_name == 'MODEL_INPUT':
-                    self._is_input = True
-                else:
-                    prev_rank = self.partition_name_to_pp_rank(partition_name)
-                    self.producer_stage_ids.append(prev_rank)
-
-            for partition_name in DAG_node['output'].keys():
-                if partition_name == 'MODEL_OUTPUT':
-                    self._is_output = True
-                else:
-                    next_rank = self.partition_name_to_pp_rank(partition_name)
-                    self.consumer_stage_ids.append(next_rank)
-
-            # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
-            with self.producer_consumer_init_lock:
-                self._producer_consumer_initialized = True
-                self.producer_consumer_init_lock.notify_all()
+        self.producer_stage_ids = self.get_producer_stage_ids()
+        self.consumer_stage_ids = self.get_consumer_stage_ids()
 
     # TODO(jiangziyue) Define a Class for DAG.
-    def pp_rank_to_partition_name(self, pp_rank: int):
-        prefix = 'submod_'
-        partition_name = prefix + str(pp_rank)
-        return partition_name
+    def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo):
+        partition_ids = topo.get_mid_partition_ids()
+        return partition_ids[pp_rank]
 
     # TODO(jiangziyue) Define a Class for DAG.
-    def partition_name_to_pp_rank(self, partition_name: str) -> int:
-        prefix = 'submod_'
-        pp_rank = int(partition_name.split(prefix)[-1])
-        return pp_rank
+    def partition_id_to_pp_rank(self, partition_id: int, topo: Topo):
+        partition_ids = topo.get_mid_partition_ids()
+        for i, id in enumerate(partition_ids):
+            if id == partition_id:
+                return i
 
-    def get_DAG(self):
+    def get_topo(self):
         with self.partition_condition_lock:
             self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
-            if hasattr(self.module_partition, '_DAG'):
-                return self.module_partition._DAG
+            if hasattr(self.module_partition, '_topo'):
+                return self.module_partition._topo
             else:
                 return None
 
     def use_middleware(self):
-        DAG = self.get_DAG()
-        return DAG is not None
+        topo = self.get_topo()
+        return topo is not None
 
     # TODO(jiangziyue) get single value instead of the whole output
     def _get_real_args_kwargs(self, args_or_kwargs):
@@ -497,58 +484,48 @@
         if args_or_kwargs is not None:
             if isinstance(args_or_kwargs, dict):
                 pass
-            else: 
+            else:
                 flatten_args = []
                 if self.is_first_stage():
                     pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
                # TODO get by offset
                 else:
-                    DAG = self.get_DAG()
-                    producer_outputs = {}
-                    cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
-                    #cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
-                    for i, args_from_one_mod in enumerate(args_or_kwargs):
-                        producer_output_offsets = []
-                        if self.need_model_input():
-                            if i == 0:
-                                producer_DAG_node = DAG['input_partition']
-                                producer_partition_name = 'MODEL_INPUT'
-                                offset = 0
-                                for arg_info in producer_DAG_node.values():
-                                    if cur_DAG_node_name in arg_info['output']:
-                                        producer_output_offsets.append(offset)
-                                    offset += 1
-                            else:
-                                producer_rank = self.producer_stage_ids[i-1]
-                                producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
-                                producer_DAG_node = DAG[producer_partition_name]
-                                producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
-
+                    topo: Topo = self.get_topo()
+                    self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+                    self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+                    model_input_partition_id = topo.get_input_partition_id()
+                    input_vals = self_partition.get_input_vals()
+                    producer_stage_ids = self.get_producer_stage_ids()
+                    if self.need_model_input():
+                        # 0 for data from input batch
+                        # >= 1 for data from prev stages
+                        base = 1
+                    else:
+                        # data from prev stages
+                        base = 0
+                    for val in input_vals:
+                        val_pos = val.get()
+                        src_partition_id = val_pos.partition_id
+                        src_offset = val_pos.offset
+                        src_index = base
+                        src_partition = topo.get_partition_by_id(src_partition_id)
+                        output_len = len(src_partition.get_output_vals())
+                        # data from not-input partition
+                        if src_partition_id != model_input_partition_id:
+                            src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
+                            src_index = base
+                            for i, stage_id in enumerate(producer_stage_ids):
+                                if stage_id == src_stage_id:
+                                    src_index += i
+                                    break
+                        else:    # data from input partition
+                            src_index = 0
+                        # when output_len = 1, not iterable
+                        if output_len == 1:
+                            target = args_or_kwargs[src_index]
                         else:
-                            producer_rank = self.producer_stage_ids[i]
-                            producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
-                            producer_DAG_node = DAG[producer_partition_name]
-                            producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
-
-                        if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
-                            producer_outputs[producer_partition_name] = [args_from_one_mod]
-                        else:
-                            producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]
-
-                    cur_DAG_node_input = DAG[cur_DAG_node_name]['input']
-
-                    def get_input_len(DAG_node_input):
-                        res = 0
-                        for offsets in DAG_node_input.values():
-                            res += len(offsets)
-                        return res
-
-                    input_len = get_input_len(cur_DAG_node_input)
-                    flatten_args = [None] * input_len
-                    for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
-                        for i, args_input_offset in enumerate(args_input_offsets):
-                            flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]
-
+                            target = args_or_kwargs[src_index][src_offset]
+                        flatten_args.append(target)
                 args_or_kwargs = flatten_args
         return args_or_kwargs
@@ -565,7 +542,15 @@
         return self.pp_rank == self.actual_stage_num - 1
 
     def need_model_input(self):
-        return not self.is_first_stage() and self._is_input
+        need_input = False
+        topo: Topo = self.get_topo()
+        self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+        self_partition = topo.get_partition_by_id(self_partition_id)
+        partition_inputs = self_partition.get_input_partition_ids()
+        model_input_partition_id = topo.get_input_partition_id()
+        if model_input_partition_id in partition_inputs:
+            need_input = True
+        return not self.is_first_stage() and need_input
 
     def _default_data_process_func(self, args_kwargs):
         if self.is_first_stage():
diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py
index ea9a3c16e..d138f8cdd 100644
--- a/tests/test_pipeline/test_middleware_1f1b.py
+++ b/tests/test_pipeline/test_middleware_1f1b.py
@@ -4,6 +4,7 @@
 from torch import nn
 from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
 from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
 from colossalai.fx import ColoTracer
+from colossalai.pipeline.middleware.adaptor import get_fx_topology
 from rpc_test_utils import rpc_run, parse_args, MLP
 from functools import partial
@@ -18,8 +19,12 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
     graph = tracer.trace(root=model, meta_args=meta_args)
     gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
     annotated_model = balanced_split_pass(gm, stage_num)
-    split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
-    return list(split_model.children())[pp_rank]
+    top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
+    topo = get_fx_topology(top_module)
+    for submodule in split_submodules:
+        if isinstance(submodule, torch.fx.GraphModule):
+            setattr(submodule, '_topo', topo)
+    return split_submodules[pp_rank+1]
 
 def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
     torch.manual_seed(1024)
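Note (not part of the patch): below is a minimal sketch of how the Topo query API introduced above can be used to recover producer stage ids for a given pipeline rank. It assumes a Topo object built by get_fx_topology and attached to the split submodules as '_topo', as done in the test; the helper name producer_stage_ids_from_topo and its standalone use outside WorkerBase are illustrative only, mirroring the logic of WorkerBase.get_producer_stage_ids.

from typing import List

from colossalai.pipeline.middleware import Partition, Topo


def producer_stage_ids_from_topo(topo: Topo, pp_rank: int) -> List[int]:
    # Illustrative helper (not part of the patch); mirrors WorkerBase.get_producer_stage_ids.
    # Mid partitions are ordered, so pipeline rank i corresponds to the i-th mid partition id.
    mid_ids = topo.get_mid_partition_ids()
    self_partition: Partition = topo.get_partition_by_id(mid_ids[pp_rank])

    model_input_partition_id = topo.get_input_partition_id()
    producer_stage_ids = []
    for partition_id in self_partition.get_input_partition_ids():
        # the model-input partition is delivered separately (Phase.INPUT), so it is skipped here
        if partition_id != model_input_partition_id:
            producer_stage_ids.append(mid_ids.index(partition_id))
    return producer_stage_ids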