diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index e28a31624..8854c73a9 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -16,7 +16,6 @@ from torch import autograd, nn, optim from torch._C._distributed_rpc import PyRRef from torch.futures import Future - class Phase(Enum): FORWARD = 0 BACKWARD = 1 @@ -136,9 +135,6 @@ class WorkerBase(ABC): self.criterion = criterion self.metric = metric - # middleware info - self._is_output = False - # context to maintain loop self._initialize_context_container() @@ -190,21 +186,33 @@ class WorkerBase(ABC): with self.output_list_condition_lock: self.output_list_condition_lock.wait_for(lambda: key in self.output_list) output_work_item = self.output_list[key] - + self.output_list.pop(key) + + output_work_item.refcount += 1 + refcount = output_work_item.refcount output = output_work_item.output + + if output_work_item.phase != Phase.INPUT: + # lifecycle management for DAG scheduler + lifecycle = len(self.get_consumer_stage_ids()) + if self.is_model_output(): # an extra reference for scheduler collecting results + lifecycle += 1 + with self.output_list_condition_lock: + # all consumers have been satisfied, the work_item can be released + # or put it into work list again. 
+ if refcount < lifecycle: + self.output_list[key] = output_work_item + self.output_list_condition_lock.notify_all() + else: + with self.output_list_condition_lock: + self.output_list[key] = output_work_item + self.output_list_condition_lock.notify_all() + if isinstance(output, Future): output = output.wait() - - # output_work_item.refcount += 1 - - # TODO(jiangziyue) redesign lifecycle management for DAG scheduler - # all consumers have been satisfied, the work_item can be released - with self.output_list_condition_lock: - if output_work_item.refcount >= len(self.consumer_stage_ids): - self.output_list.pop(key) + return output - def get_parameters(self) -> List[torch.Tensor]: return [p for p in self.module_partition.parameters()] @@ -246,8 +254,6 @@ class WorkerBase(ABC): raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}") # just for first pp_rank - # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env. - # TODO(jiangziyue) Define a Class for DAG. def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool): key = UniqueKey(microbatch_id, Phase.FORWARD) output = self._get_future_by_device() @@ -311,9 +317,8 @@ class WorkerBase(ABC): self.work_list[key] = work_item self.work_list_condition_lock.notify_all() - - # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env. 
- def subscribe_producer(self, microbatch_id: int, forward_only: bool): + + def _subscribe_producer(self, microbatch_id: int, forward_only: bool): """ You should call this function asynchronously """ @@ -328,10 +333,6 @@ class WorkerBase(ABC): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) else: - with self.work_list_condition_lock: - key = UniqueKey(microbatch_id, Phase.FORWARD) - if key in self.work_list: - return producer_stage_ids = self.get_producer_stage_ids() producer_num = len(producer_stage_ids) if self.need_model_input(): @@ -360,11 +361,19 @@ class WorkerBase(ABC): work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, microbatch_id, None, self.num_microbatches, forward_only) - - # add work_item to work_list + + return work_item_from_producer + + # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one. + def subscribe_producer(self, microbatch_id: int, forward_only: bool): + key = UniqueKey(microbatch_id, Phase.FORWARD) with self.work_list_condition_lock: - key = UniqueKey(microbatch_id, Phase.FORWARD) if key not in self.work_list: + # On current PP middleware design for DAG, get_output_by_key used by _subscribe_producer + # can only be executed once for every producer-consumer stage pair, which is necessary + # to count the lifecycle of work_item. So, keeping the _subscribe_producer in the same + # lock of work_item queue operation guarantees the consistency of lifecycle counter. 
+ work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only) self.work_list[key] = work_item_from_producer self.work_list_condition_lock.notify_all() @@ -444,12 +453,10 @@ class WorkerBase(ABC): self.producer_stage_ids = self.get_producer_stage_ids() self.consumer_stage_ids = self.get_consumer_stage_ids() - # TODO(jiangziyue) Define a Class for DAG. def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo): partition_ids = topo.get_mid_partition_ids() return partition_ids[pp_rank] - # TODO(jiangziyue) Define a Class for DAG. def partition_id_to_pp_rank(self, partition_id: int, topo: Topo): partition_ids = topo.get_mid_partition_ids() for i, id in enumerate(partition_ids): @@ -551,6 +558,9 @@ class WorkerBase(ABC): if model_input_partition_id in partition_inputs: need_input = True return not self.is_first_stage() and need_input + + def is_model_output(self): + return self.is_last_stage() def _default_data_process_func(self, args_kwargs): if self.is_first_stage(): @@ -748,7 +758,8 @@ class WorkerBase(ABC): # move current work item to output_list to activate subscribe in advance with self.work_list_condition_lock: - work_item = self.work_list.pop(work_item_key) + #self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list) + work_item = self.work_list[work_item_key] with self.output_list_condition_lock: # assert work_item_key not in self.output_list @@ -758,6 +769,8 @@ class WorkerBase(ABC): consume_result = self._consume_work_item_by_phase(work_item) work_item.output.set_result(consume_result) + with self.work_list_condition_lock: + self.work_list.pop(work_item_key) # if is last step in one batch reset context and do step if self._is_last_step(work_item): diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py index f1a4116be..853efde3f 100644 --- a/tests/test_pipeline/rpc_test_utils.py +++ b/tests/test_pipeline/rpc_test_utils.py @@ -32,6 +32,21 @@ class MLP(nn.Module): for layer in 
self.layers: x = layer(x) return x + +class DAG_MLP(nn.Module): + def __init__(self, dim: int, layers: int): + super().__init__() + self.layers = torch.nn.ModuleList() + self.dag_layer = nn.Linear(dim, dim, bias=False) + + for _ in range(layers): + self.layers.append(nn.Linear(dim, dim, bias=False)) + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + y = self.dag_layer(y) + return x, y class RpcTestModel(nn.Module): diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py index d138f8cdd..c4fb9b094 100644 --- a/tests/test_pipeline/test_middleware_1f1b.py +++ b/tests/test_pipeline/test_middleware_1f1b.py @@ -1,16 +1,26 @@ import torch -from torch import nn +import pytest +import os +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc +from torch import nn +from torch._C._distributed_rpc import _is_current_rpc_agent_set +from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass from colossalai.fx import ColoTracer from colossalai.pipeline.middleware.adaptor import get_fx_topology -from rpc_test_utils import rpc_run, parse_args, MLP +from rpc_test_utils import MLP, DAG_MLP from functools import partial +from colossalai.testing import parameterize, rerun_if_address_is_in_use # global variable for model created batch_size = 16 dim = 10 +rpc_is_initialized = _is_current_rpc_agent_set def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): model.eval() @@ -26,40 +36,82 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): setattr(submodule, '_topo', topo) return split_submodules[pp_rank+1] -def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): 
+def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): torch.manual_seed(1024) - model = MLP(dim, stage_num * 3) partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) return partition -def run_master(args): +def run_master(model_cls, world_size): torch.manual_seed(100) - epoch = args.epoch - device = args.device - stage_num = args.world_size - chunk = args.chunk - num_microbatches = args.num_microbatches - use_checkpoint = args.use_checkpoint - - input_sample = torch.randn((batch_size, dim), device=device) + epoch = 10 + device = 'cuda' + stage_num = world_size + chunk = 1 + num_microbatches = 8 + use_checkpoint = 'store_true' - def data_gen(): - x = torch.zeros((batch_size, dim)) - kwargs = dict(x=x) - return kwargs + if model_cls == MLP: + def data_gen(): + x = torch.zeros((batch_size, dim)) + kwargs = dict(x=x) + return kwargs + model = model_cls(dim, stage_num * 3) + elif model_cls == DAG_MLP: + def data_gen(): + x = torch.zeros((batch_size, dim)) + y = torch.zeros((batch_size, dim)) + kwargs = dict(x=x, y=y) + return kwargs + model = model_cls(dim, stage_num * 3) + else: + pass data_kwargs = data_gen() - engine = OneFOneBPipelineEngine(partition_fn=partial(partition, data_kwargs), + + engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs), stage_num=stage_num, num_microbatches=num_microbatches, device=device, chunk=chunk, - checkpoint=use_checkpoint) + checkpoint=use_checkpoint,) for _ in range(epoch): - logits = engine.forward_backward({'x': input_sample}, forward_only=True) + input_x = torch.randn((batch_size, dim), device=device) + input_y = torch.randn((batch_size, dim), device=device) + logits = engine.forward_backward({'x': input_x, 'y': input_y}, forward_only=True) + +def run_worker(rank, model_cls, world_size, master_func): + master_addr = 'localhost' + master_port = 29020 + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = str(master_port) + + 
disable_existing_loggers() + + launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) + ppg.set_global_info(rank=rank, + world_size=world_size, + dp_degree=1, + tp_degree=1, + num_worker_threads=128, + device='cuda') + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + master_func(model_cls, world_size) + # barrier here + if rpc_is_initialized(): + rpc.shutdown() + +@pytest.mark.skip("skip due to CI torch version 1.11") +@parameterize('model_cls', [MLP, DAG_MLP]) +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_pp_middleware_fwd(model_cls): + world_size = 4 + master_func = run_master + mp.spawn(run_worker, args=(model_cls, world_size, master_func), nprocs=world_size) if __name__ == "__main__": - args = parse_args() - rpc_run(args, run_master) \ No newline at end of file + test_pp_middleware_fwd()