[PP Middleware] Add bwd and step for PP middleware (#2111)

* add bwd and step for PP middleware

* pre-commit

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
pull/2120/head
Ziyue Jiang 2022-12-12 12:40:03 +08:00 committed by GitHub
parent 8afc001f4f
commit 09d69e1c25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 225 additions and 82 deletions

View File

@ -8,20 +8,29 @@ from typing import Any, Callable, Dict, List, Tuple
import torch
import torch.distributed.rpc as rpc
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
split_batch, tensor_shape_list, type_detail)
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (
get_batch_lengths,
pytree_filter,
pytree_map,
split_batch,
tensor_shape_list,
type_detail,
)
class Phase(Enum):
FORWARD = 0
BACKWARD = 1
UPDATE = 2
INPUT = 3
class UniqueKey:
__slots__ = ('microbatch_id', 'phase')
microbatch_id: int
@ -134,6 +143,7 @@ class WorkerBase(ABC):
self.partition_args = partition_args
self.criterion = criterion
self.metric = metric
self.reset = False
# context to maintain loop
self._initialize_context_container()
@ -164,6 +174,7 @@ class WorkerBase(ABC):
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.label_lock = threading.Condition(threading.Lock())
self.reset_condition = threading.Condition(threading.Lock())
def _initialize_partition(self):
partition_fn = self.partition_fn
@ -182,20 +193,23 @@ class WorkerBase(ABC):
# construction of partition is executed after the registion of pp_rank_to_worker_rref
self._initialize_partition()
def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
# res_use works for lifecycle counter,
# if ref_use is True, lifecycle won't add.
def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
self.output_list.pop(key)
output_work_item.refcount += 1
self.output_list.pop(key)
if not ref_use:
output_work_item.refcount += 1
refcount = output_work_item.refcount
output = output_work_item.output
if output_work_item.phase != Phase.INPUT:
if output_work_item.phase == Phase.FORWARD:
# lifecycle management for DAG scheduler
lifecycle = len(self.get_consumer_stage_ids())
if self.is_model_output(): # an extra reference for scheduler collecting results
if self.is_model_output(): # an extra reference for scheduler collecting results
lifecycle += 1
with self.output_list_condition_lock:
# all consumers have been satisfied, the work_item can be released
@ -203,14 +217,24 @@ class WorkerBase(ABC):
if refcount < lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
elif output_work_item.phase == Phase.BACKWARD:
lifecycle = len(self.get_producer_stage_ids())
if self._is_last_step(output_work_item):
lifecycle += 1 # an extra reference for scheduler collecting results
with self.output_list_condition_lock:
# all producers have been satisfied, the work_item can be released
# or put it into work list again.
if refcount < lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
else:
with self.output_list_condition_lock:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
if isinstance(output, Future):
output = output.wait()
return output
def get_parameters(self) -> List[torch.Tensor]:
@ -257,13 +281,13 @@ class WorkerBase(ABC):
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
key = UniqueKey(microbatch_id, Phase.FORWARD)
output = self._get_future_by_device()
if not self.use_middleware():
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
@ -284,14 +308,14 @@ class WorkerBase(ABC):
self_arg_lst.append(arg_lst[off])
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
self.num_microbatches, forward_only)
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# put input tensor which other nodes need into output_list as Phase.INPUT
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
self.num_microbatches, forward_only)
self.num_microbatches, forward_only)
with self.output_list_condition_lock:
self.output_list[recv_input_key] = work_item_remote
@ -317,7 +341,7 @@ class WorkerBase(ABC):
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
"""
You should call this function asynchronously
@ -336,7 +360,7 @@ class WorkerBase(ABC):
producer_stage_ids = self.get_producer_stage_ids()
producer_num = len(producer_stage_ids)
if self.need_model_input():
producer_num += 1 # for input partition
producer_num += 1 # for input partition
subscribe_forward_futures: List[Future] = [None] * producer_num
# TODO(jiangziyue) get single value instead of the whole output
@ -344,26 +368,28 @@ class WorkerBase(ABC):
producer_stage_id = 0
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
for i in range(0, producer_num-1):
for i in range(0, producer_num - 1):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key)
else:
for i in range(producer_num):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key)
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
microbatch_id, None, self.num_microbatches, forward_only)
return work_item_from_producer
# TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
key = UniqueKey(microbatch_id, Phase.FORWARD)
@ -377,20 +403,20 @@ class WorkerBase(ABC):
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()
def subscribe_consumer(self, microbatch_id: int):
def _subscribe_consumer(self, microbatch_id: int):
"""
You should call this function asynchronously
"""
assert self.producer_stage_ids is not None
consumer_num = len(self.consumer_stage_ids)
assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
stage_id = self.pp_rank
subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device()
if not self.use_middleware():
consumer_stage_ids = self.consumer_stage_ids
else:
consumer_stage_ids = self.get_consumer_stage_ids()
consumer_num = len(consumer_stage_ids)
subscribe_backward_futures: List[Future] = [None] * consumer_num
for i in range(consumer_num):
consumer_stage_id = self.consumer_stage_ids[i]
consumer_stage_id = consumer_stage_ids[i]
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
@ -399,13 +425,20 @@ class WorkerBase(ABC):
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
microbatch_id, None, self.num_microbatches, False)
# add work_item to work_list
return work_item_from_consumer
def subscribe_consumer(self, microbatch_id: int):
key = UniqueKey(microbatch_id, Phase.BACKWARD)
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.BACKWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_consumer
self.work_list_condition_lock.notify_all()
if key not in self.work_list:
# On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
# can only be executed once for every producer-consumer stage pair, which is necessary
# to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same
# lock of work_item queue operation gurantees the consistency of lifecycle counter.
work_item_from_consumer = self._subscribe_consumer(microbatch_id)
self.work_list[key] = work_item_from_consumer
self.work_list_condition_lock.notify_all()
def get_producer_stage_ids(self):
producer_stage_ids = []
rank = self.pp_rank
@ -425,7 +458,7 @@ class WorkerBase(ABC):
if partition_id != model_input_partition_id:
producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
return producer_stage_ids
def get_consumer_stage_ids(self):
consumer_stage_ids = []
rank = self.pp_rank
@ -462,7 +495,7 @@ class WorkerBase(ABC):
for i, id in enumerate(partition_ids):
if id == partition_id:
return i
def get_topo(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
@ -470,13 +503,13 @@ class WorkerBase(ABC):
return self.module_partition._topo
else:
return None
def use_middleware(self):
topo = self.get_topo()
return topo is not None
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs(self, args_or_kwargs):
def _get_real_args_kwargs_fwd(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
@ -491,8 +524,8 @@ class WorkerBase(ABC):
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
else:
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
# TODO get by offset
@ -525,7 +558,7 @@ class WorkerBase(ABC):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
else: # data from input partition
src_index = 0
# when output_len = 1, not iterable
if output_len == 1:
@ -536,6 +569,55 @@ class WorkerBase(ABC):
args_or_kwargs = flatten_args
return args_or_kwargs
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs_bwd(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
flatten_args = []
# TODO get by offset
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
output_vals = self_partition.get_output_vals()
consumer_stage_ids = self.get_consumer_stage_ids()
for val_list in output_vals:
# An output may be passed to many down stages.
target = None
for val_pos in val_list.get():
dst_partition_id = val_pos.partition_id
dst_offset = val_pos.offset
dst_partition = topo.get_partition_by_id(dst_partition_id)
input_len = len(dst_partition.get_input_vals())
dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
for i, stage_id in enumerate(consumer_stage_ids):
if stage_id == dst_stage_id:
dst_index = i
break
if input_len == 1:
part_grad = args_or_kwargs[dst_index]
else:
part_grad = args_or_kwargs[dst_index][dst_offset]
if target is None:
target = part_grad
elif part_grad is not None:
target += part_grad
else:
continue
flatten_args.append(target)
args_or_kwargs = flatten_args
return args_or_kwargs
@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
"""
@ -547,7 +629,7 @@ class WorkerBase(ABC):
def is_last_stage(self):
return self.pp_rank == self.actual_stage_num - 1
def need_model_input(self):
need_input = False
topo: Topo = self.get_topo()
@ -558,10 +640,13 @@ class WorkerBase(ABC):
if model_input_partition_id in partition_inputs:
need_input = True
return not self.is_first_stage() and need_input
def is_model_output(self):
return self.is_last_stage()
def is_model_input(self):
return self.is_first_stage()
def _default_data_process_func(self, args_kwargs):
if self.is_first_stage():
args = args_kwargs[0]
@ -598,11 +683,16 @@ class WorkerBase(ABC):
# parse and integrate args and kwargs
if is_first_stage:
args = self._get_real_args_kwargs(args)
kwargs = self._get_real_args_kwargs(kwargs)
args = self._get_real_args_kwargs_fwd(args)
kwargs = self._get_real_args_kwargs_fwd(kwargs)
args_kwargs = (args, kwargs)
else:
args_kwargs = self._get_real_args_kwargs(args)
args_kwargs = self._get_real_args_kwargs_fwd(args)
if not forward_only:
pytree_map(args_kwargs,
lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
process_types=torch.Tensor)
args, kwargs = data_process_func(args_kwargs)
@ -694,21 +784,40 @@ class WorkerBase(ABC):
# overlap recompute and future.wait
if not is_last_stage:
grad_tensors = self._get_real_args_kwargs(args)
grad_tensors = self._get_real_args_kwargs_bwd(args)
else:
grad_tensors = None
# take tensor only (for only tensor can do backward)
stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
# TODO(jiangziyue) : All values which should do bp are torch.Tensor?
stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor)
grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor)
# output all input's grad to producer, even it has no grad(output None)
# to make the offset aligned to the topo's record.
if grad_tensors is not None:
filtered_outputs = []
filtered_grads = []
for i, grad in enumerate(grad_tensors):
stage_output = stage_outputs[i]
if stage_output.requires_grad and grad is not None:
filtered_outputs.append(stage_output)
filtered_grads.append(grad)
stage_outputs = filtered_outputs
grad_tensors = filtered_grads
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
consume_result = []
if not is_first_stage:
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
# In current design, input mush be a flatten args.
for arg in stage_input_args:
if isinstance(arg, torch.Tensor):
consume_result.append(arg.grad)
else:
consume_result.append(None)
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@ -740,11 +849,11 @@ class WorkerBase(ABC):
def _hook_before_step(self):
pass
def _reset_context(self):
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self._initialize_outstanding_range()
# install the main loop to wait for next batch input
def _wait_for_reset(self):
with self.reset_condition:
self.reset_condition.wait_for(lambda: self.reset)
self.reset = False
# do the main loop to consume ready_list
def _work_loop(self):
@ -755,10 +864,9 @@ class WorkerBase(ABC):
# main loop
while True:
work_item_key = self._get_work_item_key()
# move current work item to output_list to activate subscribe in advance
with self.work_list_condition_lock:
#self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
work_item = self.work_list[work_item_key]
with self.output_list_condition_lock:
@ -768,16 +876,32 @@ class WorkerBase(ABC):
consume_result = self._consume_work_item_by_phase(work_item)
work_item.output.set_result(consume_result)
with self.work_list_condition_lock:
self.work_list.pop(work_item_key)
work_item.output.set_result(consume_result)
# if is last step in one batch reset context and do step
if self._is_last_step(work_item):
self._hook_before_step()
if hasattr(self, 'optimizer') and not work_item.forward_only:
self.step()
self._reset_context()
self._wait_for_reset()
# reset context and resume loop
def reset_context(self):
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self._initialize_outstanding_range()
with self.work_list_condition_lock:
self.work_list.clear()
with self.output_list_condition_lock:
self.output_list.clear()
with self.reset_condition:
self.reset = True
self.reset_condition.notify_all()
def initialize_optimizer(self, optimizer_class: type, **kwargs):
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
@ -856,7 +980,7 @@ class PipelineEngineBase(ABC, nn.Module):
def _create_pp_rank_to_rpc_worker_id(self) -> None:
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
of partitions will be moved to device 0 and the others to device 1
"""
@ -947,7 +1071,7 @@ class PipelineEngineBase(ABC, nn.Module):
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_sync().get_output_by_key(key)
worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
num_microbatches = self.num_microbatches
@ -965,6 +1089,7 @@ class PipelineEngineBase(ABC, nn.Module):
# TODO : add relationship between output_pp_ranks and parts of microlabels
worker_rref.remote().set_labels(microbatch_id, microlabels)
# TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
key = UniqueKey(microbatch_id, Phase.FORWARD)
for pp_rank in output_pp_ranks:
@ -993,6 +1118,16 @@ class PipelineEngineBase(ABC, nn.Module):
return forward_result
def _reset_worker(self):
actual_stage_num = self._get_actual_stage_num()
for pp_rank in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
fut = worker_rref.rpc_async().reset_context()
self.step_futs.append(fut)
for fut in self.step_futs:
fut.wait()
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
batch_lengths = get_batch_lengths(batch)
batch_length = batch_lengths[0]
@ -1046,6 +1181,7 @@ class PipelineEngineBase(ABC, nn.Module):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_sync().wait_for_step()
self._reset_worker() # reset worker attributes for next batch
return forward_result
def initialize_optimizer(self, optimizer_class: type, **kwargs):

View File

@ -89,9 +89,6 @@ class OneFOneBWorker(WorkerBase):
elif target_key.microbatch_id == num_microbatches - 1:
self.outstanding_range = (0, 0)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key

View File

@ -57,7 +57,6 @@ def split_batch(batch: Any, start, stop, device: str):
def type_detail(obj):
return pytree_map(obj, lambda x: type(x), map_all=True)
def pytree_filter(fn, obj, process_types):
if obj is None:
return None

View File

@ -31,7 +31,7 @@ class MLP(nn.Module):
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
return x.sum()
class DAG_MLP(nn.Module):
def __init__(self, dim: int, layers: int):
@ -46,7 +46,7 @@ class DAG_MLP(nn.Module):
for layer in self.layers:
x = layer(x)
y = self.dag_layer(y)
return x, y
return x.sum(), y.sum()
class RpcTestModel(nn.Module):

View File

@ -41,10 +41,10 @@ def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
return partition
def run_master(model_cls, world_size):
def run_master(model_cls, world_size, forward_only):
torch.manual_seed(100)
epoch = 10
epoch = 3
device = 'cuda'
stage_num = world_size
chunk = 1
@ -57,6 +57,10 @@ def run_master(model_cls, world_size):
kwargs = dict(x=x)
return kwargs
model = model_cls(dim, stage_num * 3)
if forward_only:
labels = None
else:
labels = 1
elif model_cls == DAG_MLP:
def data_gen():
x = torch.zeros((batch_size, dim))
@ -64,24 +68,30 @@ def run_master(model_cls, world_size):
kwargs = dict(x=x, y=y)
return kwargs
model = model_cls(dim, stage_num * 3)
if forward_only:
labels = None
else:
labels = 1
else:
pass
data_kwargs = data_gen()
engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=chunk,
checkpoint=use_checkpoint,)
if not forward_only:
engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
for _ in range(epoch):
input_x = torch.randn((batch_size, dim), device=device)
input_y = torch.randn((batch_size, dim), device=device)
logits = engine.forward_backward({'x': input_x, 'y': input_y}, forward_only=True)
logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
def run_worker(rank, model_cls, world_size, master_func):
def run_worker(rank, model_cls, world_size, forward_only, master_func):
master_addr = 'localhost'
master_port = 29020
os.environ['MASTER_ADDR'] = master_addr
@ -99,19 +109,20 @@ def run_worker(rank, model_cls, world_size, master_func):
# in rpc mode, only rank 0 is needed to be coded
if rank == 0:
master_func(model_cls, world_size)
master_func(model_cls, world_size, forward_only)
# barrier here
if rpc_is_initialized():
rpc.shutdown()
@pytest.mark.skip("skip due to CI torch version 1.11")
@parameterize('model_cls', [MLP, DAG_MLP])
@parameterize('forward_only', [True, False])
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp_middleware_fwd(model_cls):
def test_pp_middleware_fwd(model_cls, forward_only):
world_size = 4
master_func = run_master
mp.spawn(run_worker, args=(model_cls, world_size, master_func), nprocs=world_size)
mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size)
if __name__ == "__main__":
test_pp_middleware_fwd()
test_pp_middleware_fwd()