diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py
index 8854c73a9..ae1cbb0c4 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -8,20 +8,29 @@
 from typing import Any, Callable, Dict, List, Tuple
 
 import torch
 import torch.distributed.rpc as rpc
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
-                                           split_batch, tensor_shape_list, type_detail)
-from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
 from torch import autograd, nn, optim
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
 
+from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc.utils import (
+    get_batch_lengths,
+    pytree_filter,
+    pytree_map,
+    split_batch,
+    tensor_shape_list,
+    type_detail,
+)
+
+
 class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
     UPDATE = 2
     INPUT = 3
 
+
 class UniqueKey:
     __slots__ = ('microbatch_id', 'phase')
     microbatch_id: int
@@ -134,6 +143,7 @@ class WorkerBase(ABC):
         self.partition_args = partition_args
         self.criterion = criterion
         self.metric = metric
+        self.reset = False
 
         # context to maintain loop
         self._initialize_context_container()
@@ -164,6 +174,7 @@ class WorkerBase(ABC):
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
         self.label_lock = threading.Condition(threading.Lock())
+        self.reset_condition = threading.Condition(threading.Lock())
 
     def _initialize_partition(self):
         partition_fn = self.partition_fn
@@ -182,20 +193,23 @@ class WorkerBase(ABC):
             # construction of partition is executed after the registion of pp_rank_to_worker_rref
             self._initialize_partition()
 
-    def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
+    # ref_use works for the lifecycle counter:
+    # if ref_use is True, the lifecycle count is not increased.
+    def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
-        self.output_list.pop(key)
-
-        output_work_item.refcount += 1
+            self.output_list.pop(key)
+
+        if not ref_use:
+            output_work_item.refcount += 1
         refcount = output_work_item.refcount
         output = output_work_item.output
-        if output_work_item.phase != Phase.INPUT:
+        if output_work_item.phase == Phase.FORWARD:
             # lifecycle management for DAG scheduler
             lifecycle = len(self.get_consumer_stage_ids())
-            if self.is_model_output():  # an extra reference for scheduler collecting results
+            if self.is_model_output():    # an extra reference for scheduler collecting results
                 lifecycle += 1
             with self.output_list_condition_lock:
                 # all consumers have been satisfied, the work_item can be released
@@ -203,14 +217,24 @@
                 if refcount < lifecycle:
                     self.output_list[key] = output_work_item
                 self.output_list_condition_lock.notify_all()
+        elif output_work_item.phase == Phase.BACKWARD:
+            lifecycle = len(self.get_producer_stage_ids())
+            if self._is_last_step(output_work_item):
+                lifecycle += 1    # an extra reference for scheduler collecting results
+            with self.output_list_condition_lock:
+                # all producers have been satisfied, the work_item can be released
+                # or put it into work list again.
+                if refcount < lifecycle:
+                    self.output_list[key] = output_work_item
+                self.output_list_condition_lock.notify_all()
         else:
             with self.output_list_condition_lock:
                 self.output_list[key] = output_work_item
                 self.output_list_condition_lock.notify_all()
-
+
         if isinstance(output, Future):
             output = output.wait()
-
+
         return output
 
     def get_parameters(self) -> List[torch.Tensor]:
@@ -257,13 +281,13 @@ class WorkerBase(ABC):
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         output = self._get_future_by_device()
-
+
         if not self.use_middleware():
             # make args and kwargs
             args, kwargs = self._make_args_kwargs(microbatch)
 
             work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
-                                     self.num_microbatches, forward_only)
+                                 self.num_microbatches, forward_only)
             with self.work_list_condition_lock:
                 self.work_list[key] = work_item
                 self.work_list_condition_lock.notify_all()
@@ -284,14 +308,14 @@ class WorkerBase(ABC):
                 self_arg_lst.append(arg_lst[off])
 
             work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
-                                     self.num_microbatches, forward_only)
+                                 self.num_microbatches, forward_only)
             with self.work_list_condition_lock:
                 self.work_list[key] = work_item
                 self.work_list_condition_lock.notify_all()
 
             # put input tensor which other nodes need into output_list as Phase.INPUT
             work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
-                                            self.num_microbatches, forward_only)
+                                        self.num_microbatches, forward_only)
 
             with self.output_list_condition_lock:
                 self.output_list[recv_input_key] = work_item_remote
@@ -317,7 +341,7 @@ class WorkerBase(ABC):
                 self.work_list[key] = work_item
                 self.work_list_condition_lock.notify_all()
-
+
     def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
         """
         You should call this function asynchronously
         """
@@ -336,7 +360,7 @@ class WorkerBase(ABC):
             producer_stage_ids = self.get_producer_stage_ids()
             producer_num = len(producer_stage_ids)
             if self.need_model_input():
-                producer_num += 1 # for input partition
+                producer_num += 1    # for input partition
             subscribe_forward_futures: List[Future] = [None] * producer_num
 
             # TODO(jiangziyue) get single value instead of the whole output
@@ -344,26 +368,28 @@
                 producer_stage_id = 0
                 producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
                 producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
 
-                for i in range(0, producer_num-1):
+                for i in range(0, producer_num - 1):
                     producer_stage_id = producer_stage_ids[i]
                     producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                     producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                    subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+                    subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
+                        producer_output_key)
 
             else:
                 for i in range(producer_num):
                     producer_stage_id = producer_stage_ids[i]
                     producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                     producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                    subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+                    subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
+                        producer_output_key)
 
         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
-                                       microbatch_id, None, self.num_microbatches, forward_only)
-
+                                           microbatch_id, None, self.num_microbatches, forward_only)
+
         return work_item_from_producer
-
+
     # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
     def subscribe_producer(self, microbatch_id: int, forward_only: bool):
         key = UniqueKey(microbatch_id, Phase.FORWARD)
@@ -377,20 +403,20 @@ class WorkerBase(ABC):
                 self.work_list[key] = work_item_from_producer
                 self.work_list_condition_lock.notify_all()
 
-    def subscribe_consumer(self, microbatch_id: int):
+    def _subscribe_consumer(self, microbatch_id: int):
         """
         You should call this function asynchronously
         """
-        assert self.producer_stage_ids is not None
-        consumer_num = len(self.consumer_stage_ids)
-        assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
-        stage_id = self.pp_rank
-        subscribe_backward_futures: List[Future] = [None] * consumer_num
         output = self._get_future_by_device()
-
+        if not self.use_middleware():
+            consumer_stage_ids = self.consumer_stage_ids
+        else:
+            consumer_stage_ids = self.get_consumer_stage_ids()
+        consumer_num = len(consumer_stage_ids)
+        subscribe_backward_futures: List[Future] = [None] * consumer_num
         for i in range(consumer_num):
-            consumer_stage_id = self.consumer_stage_ids[i]
+            consumer_stage_id = consumer_stage_ids[i]
             consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
             consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
             subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
@@ -399,13 +425,20 @@ class WorkerBase(ABC):
 
         work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, False)
 
-        # add work_item to work_list
+        return work_item_from_consumer
+
+    def subscribe_consumer(self, microbatch_id: int):
+        key = UniqueKey(microbatch_id, Phase.BACKWARD)
         with self.work_list_condition_lock:
-            key = UniqueKey(microbatch_id, Phase.BACKWARD)
-            assert key not in self.work_list
-            self.work_list[key] = work_item_from_consumer
-            self.work_list_condition_lock.notify_all()
-
+            if key not in self.work_list:
+                # On the current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
+                # can only be executed once for every producer-consumer stage pair, which is necessary
+                # to count the lifecycle of the work_item. So, keeping subscribe_consumer inside the same
+                # lock as the work_item queue operation guarantees the consistency of the lifecycle counter.
+                work_item_from_consumer = self._subscribe_consumer(microbatch_id)
+                self.work_list[key] = work_item_from_consumer
+                self.work_list_condition_lock.notify_all()
+
     def get_producer_stage_ids(self):
         producer_stage_ids = []
         rank = self.pp_rank
@@ -425,7 +458,7 @@ class WorkerBase(ABC):
             if partition_id != model_input_partition_id:
                 producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
         return producer_stage_ids
-
+
     def get_consumer_stage_ids(self):
         consumer_stage_ids = []
         rank = self.pp_rank
@@ -462,7 +495,7 @@ class WorkerBase(ABC):
         for i, id in enumerate(partition_ids):
             if id == partition_id:
                 return i
-
+
     def get_topo(self):
         with self.partition_condition_lock:
             self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
@@ -470,13 +503,13 @@ class WorkerBase(ABC):
             return self.module_partition._topo
         else:
             return None
-
+
     def use_middleware(self):
         topo = self.get_topo()
         return topo is not None
 
     # TODO(jiangziyue) get single value instead of the whole output
-    def _get_real_args_kwargs(self, args_or_kwargs):
+    def _get_real_args_kwargs_fwd(self, args_or_kwargs):
         if not self.use_middleware():
             args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
             if args_or_kwargs is not None:
@@ -491,8 +524,8 @@ class WorkerBase(ABC):
             if args_or_kwargs is not None:
                 if isinstance(args_or_kwargs, dict):
                     pass
-                else:
-                    flatten_args = []
+                else:
+                    flatten_args = []
                     if self.is_first_stage():
                         pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
                     # TODO get by offset
@@ -525,7 +558,7 @@ class WorkerBase(ABC):
                             if stage_id == src_stage_id:
                                 src_index += i
                                 break
-                        else: # data from input partition
+                        else:    # data from input partition
                             src_index = 0
                         # when output_len = 1, not iterable
                         if output_len == 1:
@@ -536,6 +569,55 @@ class WorkerBase(ABC):
                     args_or_kwargs = flatten_args
         return args_or_kwargs
 
+    # TODO(jiangziyue) get single value instead of the whole output
+    def _get_real_args_kwargs_bwd(self, args_or_kwargs):
+        if not self.use_middleware():
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                if isinstance(args_or_kwargs, dict):
+                    pass
+                else:
+                    flatten_args = []
+                    pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+                    args_or_kwargs = flatten_args
+        else:
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                flatten_args = []
+                # TODO get by offset
+                topo: Topo = self.get_topo()
+                self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+                self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+                output_vals = self_partition.get_output_vals()
+                consumer_stage_ids = self.get_consumer_stage_ids()
+                for val_list in output_vals:
+                    # An output may be passed to many downstream stages.
+                    target = None
+                    for val_pos in val_list.get():
+                        dst_partition_id = val_pos.partition_id
+                        dst_offset = val_pos.offset
+                        dst_partition = topo.get_partition_by_id(dst_partition_id)
+                        input_len = len(dst_partition.get_input_vals())
+                        dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
+                        for i, stage_id in enumerate(consumer_stage_ids):
+                            if stage_id == dst_stage_id:
+                                dst_index = i
+                                break
+                        if input_len == 1:
+                            part_grad = args_or_kwargs[dst_index]
+                        else:
+                            part_grad = args_or_kwargs[dst_index][dst_offset]
+
+                        if target is None:
+                            target = part_grad
+                        elif part_grad is not None:
+                            target += part_grad
+                        else:
+                            continue
+                    flatten_args.append(target)
+                args_or_kwargs = flatten_args
+        return args_or_kwargs
+
     @abstractmethod
     def _get_work_item_key(self) -> UniqueKey:
         """
@@ -547,7 +629,7 @@
     def is_last_stage(self):
         return self.pp_rank == self.actual_stage_num - 1
-
+
     def need_model_input(self):
         need_input = False
         topo: Topo = self.get_topo()
@@ -558,10 +640,13 @@ class WorkerBase(ABC):
             if model_input_partition_id in partition_inputs:
                 need_input = True
         return not self.is_first_stage() and need_input
-
+
     def is_model_output(self):
         return self.is_last_stage()
 
+    def is_model_input(self):
+        return self.is_first_stage()
+
     def _default_data_process_func(self, args_kwargs):
         if self.is_first_stage():
             args = args_kwargs[0]
@@ -598,11 +683,16 @@ class WorkerBase(ABC):
 
         # parse and integrate args and kwargs
         if is_first_stage:
-            args = self._get_real_args_kwargs(args)
-            kwargs = self._get_real_args_kwargs(kwargs)
+            args = self._get_real_args_kwargs_fwd(args)
+            kwargs = self._get_real_args_kwargs_fwd(kwargs)
             args_kwargs = (args, kwargs)
         else:
-            args_kwargs = self._get_real_args_kwargs(args)
+            args_kwargs = self._get_real_args_kwargs_fwd(args)
+
+        if not forward_only:
+            pytree_map(args_kwargs,
+                       lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
+                       process_types=torch.Tensor)
 
         args, kwargs = data_process_func(args_kwargs)
 
@@ -694,21 +784,40 @@ class WorkerBase(ABC):
 
             # overlap recompute and future.wait
            if not is_last_stage:
-                grad_tensors = self._get_real_args_kwargs(args)
+                grad_tensors = self._get_real_args_kwargs_bwd(args)
             else:
                 grad_tensors = None
 
             # take tensor only (for only tensor can do backward)
-            stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
-            grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
+            # TODO(jiangziyue): are all values that need backward torch.Tensor?
+            stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor)
+            grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor)
+
+            # output all of the inputs' grads to the producer, even if an input has no grad (output None),
+            # to keep the offsets aligned with the topo's record.
+            if grad_tensors is not None:
+                filtered_outputs = []
+                filtered_grads = []
+                for i, grad in enumerate(grad_tensors):
+                    stage_output = stage_outputs[i]
+                    if stage_output.requires_grad and grad is not None:
+                        filtered_outputs.append(stage_output)
+                        filtered_grads.append(grad)
+
+                stage_outputs = filtered_outputs
+                grad_tensors = filtered_grads
 
             autograd.backward(stage_outputs, grad_tensors=grad_tensors)
 
             # collect grad of input tensor
             consume_result = []
             if not is_first_stage:
-                pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
-                pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
+                # In the current design, the input must be flattened args.
+                for arg in stage_input_args:
+                    if isinstance(arg, torch.Tensor):
+                        consume_result.append(arg.grad)
+                    else:
+                        consume_result.append(None)
 
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -740,11 +849,11 @@ class WorkerBase(ABC):
     def _hook_before_step(self):
         pass
 
-    def _reset_context(self):
-        self.forward_times = 0
-        self.backward_times = 0
-        self.outstanding = 0
-        self._initialize_outstanding_range()
+    # stall the main loop to wait for the next batch input
+    def _wait_for_reset(self):
+        with self.reset_condition:
+            self.reset_condition.wait_for(lambda: self.reset)
+            self.reset = False
 
     # do the main loop to consume ready_list
     def _work_loop(self):
@@ -755,10 +864,9 @@ class WorkerBase(ABC):
         # main loop
         while True:
             work_item_key = self._get_work_item_key()
-            # move current work item to output_list to activate subscribe in advance
 
             with self.work_list_condition_lock:
-                #self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
+                self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
                 work_item = self.work_list[work_item_key]
 
             with self.output_list_condition_lock:
@@ -768,16 +876,32 @@ class WorkerBase(ABC):
 
             consume_result = self._consume_work_item_by_phase(work_item)
 
-            work_item.output.set_result(consume_result)
             with self.work_list_condition_lock:
                 self.work_list.pop(work_item_key)
+            work_item.output.set_result(consume_result)
 
             # if is last step in one batch reset context and do step
             if self._is_last_step(work_item):
                 self._hook_before_step()
                 if hasattr(self, 'optimizer') and not work_item.forward_only:
                     self.step()
-                self._reset_context()
+                self._wait_for_reset()
+
+    # reset context and resume loop
+    def reset_context(self):
+        self.forward_times = 0
+        self.backward_times = 0
+        self.outstanding = 0
+        self._initialize_outstanding_range()
+        with self.work_list_condition_lock:
+            self.work_list.clear()
+
+        with self.output_list_condition_lock:
+            self.output_list.clear()
+
+        with self.reset_condition:
+            self.reset = True
+            self.reset_condition.notify_all()
 
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
         # TODO(jiangziyue) it's temporary code to deal with empty module partition.
@@ -856,7 +980,7 @@ class PipelineEngineBase(ABC, nn.Module):
 
     def _create_pp_rank_to_rpc_worker_id(self) -> None:
         """create a map from model partition to stage_id, which is useful when use_interleave is True.
-        e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
+        e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
         pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part of partitions will be moved to device 0
         and the others to device 1
         """
@@ -947,7 +1071,7 @@ class PipelineEngineBase(ABC, nn.Module):
             key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
             for pp_rank in input_pp_ranks:
                 worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-                worker_rref.rpc_sync().get_output_by_key(key)
+                worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
 
     def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
         num_microbatches = self.num_microbatches
@@ -965,6 +1089,7 @@ class PipelineEngineBase(ABC, nn.Module):
             # TODO : add relationship between output_pp_ranks and parts of microlabels
             worker_rref.remote().set_labels(microbatch_id, microlabels)
 
+    # TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
     def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         for pp_rank in output_pp_ranks:
@@ -993,6 +1118,16 @@ class PipelineEngineBase(ABC, nn.Module):
 
         return forward_result
 
+    def _reset_worker(self):
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            fut = worker_rref.rpc_async().reset_context()
+            self.step_futs.append(fut)
+
+        for fut in self.step_futs:
+            fut.wait()
+
     def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
         batch_lengths = get_batch_lengths(batch)
         batch_length = batch_lengths[0]
@@ -1046,6 +1181,7 @@ class PipelineEngineBase(ABC, nn.Module):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
             worker_rref.rpc_sync().wait_for_step()
 
+        self._reset_worker()    # reset worker attributes for next batch
         return forward_result
 
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py
index 0ab3a3694..555955583 100644
--- a/colossalai/pipeline/rpc/_pipeline_schedule.py
+++ b/colossalai/pipeline/rpc/_pipeline_schedule.py
@@ -89,9 +89,6 @@ class OneFOneBWorker(WorkerBase):
         elif target_key.microbatch_id == num_microbatches - 1:
             self.outstanding_range = (0, 0)
 
-        with self.work_list_condition_lock:
-            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
-
         return target_key
diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py
index 361f6faf7..77d601173 100644
--- a/colossalai/pipeline/rpc/utils.py
+++ b/colossalai/pipeline/rpc/utils.py
@@ -57,7 +57,6 @@ def split_batch(batch: Any, start, stop, device: str):
 def type_detail(obj):
     return pytree_map(obj, lambda x: type(x), map_all=True)
 
-
 def pytree_filter(fn, obj, process_types):
     if obj is None:
         return None
diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py
index 853efde3f..7ce2cd433 100644
--- a/tests/test_pipeline/rpc_test_utils.py
+++ b/tests/test_pipeline/rpc_test_utils.py
@@ -31,7 +31,7 @@ class MLP(nn.Module):
     def forward(self, x):
         for layer in self.layers:
             x = layer(x)
-        return x
+        return x.sum()
 
 
 class DAG_MLP(nn.Module):
     def __init__(self, dim: int, layers: int):
@@ -46,7 +46,7 @@ class DAG_MLP(nn.Module):
         for layer in self.layers:
             x = layer(x)
             y = self.dag_layer(y)
-        return x, y
+        return x.sum(), y.sum()
 
 
 class RpcTestModel(nn.Module):
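The core behavioral change in `_pipeline_base.py` above is the batch-boundary handshake: a worker that finishes a batch now parks in `_wait_for_reset()` until the engine calls `reset_context()` on it (via `_reset_worker()`), instead of resetting itself. The sketch below isolates that pattern using nothing beyond the Python standard library; `ToyWorker` and `drive` are illustrative names invented for this sketch and are not part of the patch.

```python
import threading
import time


class ToyWorker:
    """Minimal stand-in for the reset handshake WorkerBase uses between batches."""

    def __init__(self):
        self.reset = False
        self.reset_condition = threading.Condition(threading.Lock())

    def _wait_for_reset(self):
        # Same pattern as WorkerBase._wait_for_reset: block until the driver
        # signals that per-batch state has been cleared, then rearm the flag.
        with self.reset_condition:
            self.reset_condition.wait_for(lambda: self.reset)
            self.reset = False

    def reset_context(self):
        # Same pattern as WorkerBase.reset_context: clear per-batch state
        # (omitted here) and wake the parked work loop for the next batch.
        with self.reset_condition:
            self.reset = True
            self.reset_condition.notify_all()

    def work_loop(self, num_batches: int):
        for _ in range(num_batches):
            # ... consume all work items belonging to this batch ...
            self._wait_for_reset()


def drive(worker: ToyWorker, num_batches: int):
    loop = threading.Thread(target=worker.work_loop, args=(num_batches,))
    loop.start()
    for _ in range(num_batches):
        time.sleep(0.01)          # stands in for one forward_backward() call
        worker.reset_context()    # analogous to PipelineEngineBase._reset_worker()
    loop.join()


if __name__ == '__main__':
    drive(ToyWorker(), num_batches=3)
```

Because the flag persists until the worker rearms it, the handshake does not deadlock even if the driver signals the reset before the worker reaches `_wait_for_reset()`.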
diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py
index c4fb9b094..c4dc617b1 100644
--- a/tests/test_pipeline/test_middleware_1f1b.py
+++ b/tests/test_pipeline/test_middleware_1f1b.py
@@ -41,10 +41,10 @@ def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int
     partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
     return partition
 
-def run_master(model_cls, world_size):
+def run_master(model_cls, world_size, forward_only):
     torch.manual_seed(100)
 
-    epoch = 10
+    epoch = 3
     device = 'cuda'
     stage_num = world_size
     chunk = 1
@@ -57,6 +57,10 @@ def run_master(model_cls, world_size):
             kwargs = dict(x=x)
             return kwargs
         model = model_cls(dim, stage_num * 3)
+        if forward_only:
+            labels = None
+        else:
+            labels = 1
     elif model_cls == DAG_MLP:
         def data_gen():
             x = torch.zeros((batch_size, dim))
@@ -64,24 +68,30 @@ def run_master(model_cls, world_size):
             kwargs = dict(x=x, y=y)
             return kwargs
         model = model_cls(dim, stage_num * 3)
+        if forward_only:
+            labels = None
+        else:
+            labels = 1
     else:
         pass
 
     data_kwargs = data_gen()
-
     engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
                                     stage_num=stage_num,
                                     num_microbatches=num_microbatches,
                                     device=device,
                                     chunk=chunk,
                                     checkpoint=use_checkpoint,)
+    if not forward_only:
+        engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
 
     for _ in range(epoch):
         input_x = torch.randn((batch_size, dim), device=device)
         input_y = torch.randn((batch_size, dim), device=device)
-        logits = engine.forward_backward({'x': input_x, 'y': input_y}, forward_only=True)
+        logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
 
-def run_worker(rank, model_cls, world_size, master_func):
+def run_worker(rank, model_cls, world_size, forward_only, master_func):
     master_addr = 'localhost'
     master_port = 29020
     os.environ['MASTER_ADDR'] = master_addr
@@ -99,19 +109,20 @@ def run_worker(rank, model_cls, world_size, master_func):
 
     # in rpc mode, only rank 0 is needed to be coded
     if rank == 0:
-        master_func(model_cls, world_size)
+        master_func(model_cls, world_size, forward_only)
     # barrier here
     if rpc_is_initialized():
         rpc.shutdown()
 
 @pytest.mark.skip("skip due to CI torch version 1.11")
 @parameterize('model_cls', [MLP, DAG_MLP])
+@parameterize('forward_only', [True, False])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_pp_middleware_fwd(model_cls):
+def test_pp_middleware_fwd(model_cls, forward_only):
     world_size = 4
     master_func = run_master
-    mp.spawn(run_worker, args=(model_cls, world_size, master_func), nprocs=world_size)
+    mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size)
 
 if __name__ == "__main__":
-    test_pp_middleware_fwd()
+    test_pp_middleware_fwd()
\ No newline at end of file
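The other piece of the patch that benefits from a standalone illustration is the gradient merging inside `_get_real_args_kwargs_bwd`: a stage output may feed several consumer stages, so the backward pass sums the partial gradients coming back from each consumer while tolerating consumers that return `None`. The helper below, `merge_partial_grads`, is hypothetical and written only to demonstrate that accumulation rule; the patch performs it inline inside the method.

```python
from typing import List, Optional

import torch


def merge_partial_grads(partial_grads: List[Optional[torch.Tensor]]) -> Optional[torch.Tensor]:
    """Sum the gradients one stage output received from its consumer stages."""
    target: Optional[torch.Tensor] = None
    for part_grad in partial_grads:
        if part_grad is None:
            # This consumer sent no gradient for the output; keep its slot but skip it.
            continue
        if target is None:
            # Clone so accumulation never modifies a consumer's buffer
            # (a conservative choice made for this sketch).
            target = part_grad.clone()
        else:
            target += part_grad
    return target


if __name__ == '__main__':
    g_from_stage_2 = torch.ones(2, 2)
    g_from_stage_3 = 2 * torch.ones(2, 2)
    merged = merge_partial_grads([g_from_stage_2, None, g_from_stage_3])
    assert torch.equal(merged, 3 * torch.ones(2, 2))
```

Returning `None` when every consumer sends `None` mirrors how the patch leaves an output without gradient while still keeping the per-output offsets aligned with the topo's record.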