mirror of https://github.com/hpcaitech/ColossalAI
[PP Middleware] Add bwd and step for PP middleware (#2111)
* add bwd and step for PP middleware * pre-commit Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>pull/2120/head
parent
8afc001f4f
commit
09d69e1c25
|
@ -8,20 +8,29 @@ from typing import Any, Callable, Dict, List, Tuple
|
|||
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
|
||||
split_batch, tensor_shape_list, type_detail)
|
||||
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
||||
from torch import autograd, nn, optim
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
from torch.futures import Future
|
||||
|
||||
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc.utils import (
|
||||
get_batch_lengths,
|
||||
pytree_filter,
|
||||
pytree_map,
|
||||
split_batch,
|
||||
tensor_shape_list,
|
||||
type_detail,
|
||||
)
|
||||
|
||||
|
||||
class Phase(Enum):
|
||||
FORWARD = 0
|
||||
BACKWARD = 1
|
||||
UPDATE = 2
|
||||
INPUT = 3
|
||||
|
||||
|
||||
class UniqueKey:
|
||||
__slots__ = ('microbatch_id', 'phase')
|
||||
microbatch_id: int
|
||||
|
@ -134,6 +143,7 @@ class WorkerBase(ABC):
|
|||
self.partition_args = partition_args
|
||||
self.criterion = criterion
|
||||
self.metric = metric
|
||||
self.reset = False
|
||||
|
||||
# context to maintain loop
|
||||
self._initialize_context_container()
|
||||
|
@ -164,6 +174,7 @@ class WorkerBase(ABC):
|
|||
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.label_lock = threading.Condition(threading.Lock())
|
||||
self.reset_condition = threading.Condition(threading.Lock())
|
||||
|
||||
def _initialize_partition(self):
|
||||
partition_fn = self.partition_fn
|
||||
|
@ -182,20 +193,23 @@ class WorkerBase(ABC):
|
|||
# construction of partition is executed after the registion of pp_rank_to_worker_rref
|
||||
self._initialize_partition()
|
||||
|
||||
def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
|
||||
# res_use works for lifecycle counter,
|
||||
# if ref_use is True, lifecycle won't add.
|
||||
def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
||||
output_work_item = self.output_list[key]
|
||||
self.output_list.pop(key)
|
||||
|
||||
output_work_item.refcount += 1
|
||||
self.output_list.pop(key)
|
||||
|
||||
if not ref_use:
|
||||
output_work_item.refcount += 1
|
||||
refcount = output_work_item.refcount
|
||||
output = output_work_item.output
|
||||
|
||||
if output_work_item.phase != Phase.INPUT:
|
||||
if output_work_item.phase == Phase.FORWARD:
|
||||
# lifecycle management for DAG scheduler
|
||||
lifecycle = len(self.get_consumer_stage_ids())
|
||||
if self.is_model_output(): # an extra reference for scheduler collecting results
|
||||
if self.is_model_output(): # an extra reference for scheduler collecting results
|
||||
lifecycle += 1
|
||||
with self.output_list_condition_lock:
|
||||
# all consumers have been satisfied, the work_item can be released
|
||||
|
@ -203,14 +217,24 @@ class WorkerBase(ABC):
|
|||
if refcount < lifecycle:
|
||||
self.output_list[key] = output_work_item
|
||||
self.output_list_condition_lock.notify_all()
|
||||
elif output_work_item.phase == Phase.BACKWARD:
|
||||
lifecycle = len(self.get_producer_stage_ids())
|
||||
if self._is_last_step(output_work_item):
|
||||
lifecycle += 1 # an extra reference for scheduler collecting results
|
||||
with self.output_list_condition_lock:
|
||||
# all producers have been satisfied, the work_item can be released
|
||||
# or put it into work list again.
|
||||
if refcount < lifecycle:
|
||||
self.output_list[key] = output_work_item
|
||||
self.output_list_condition_lock.notify_all()
|
||||
else:
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list[key] = output_work_item
|
||||
self.output_list_condition_lock.notify_all()
|
||||
|
||||
|
||||
if isinstance(output, Future):
|
||||
output = output.wait()
|
||||
|
||||
|
||||
return output
|
||||
|
||||
def get_parameters(self) -> List[torch.Tensor]:
|
||||
|
@ -257,13 +281,13 @@ class WorkerBase(ABC):
|
|||
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
output = self._get_future_by_device()
|
||||
|
||||
|
||||
if not self.use_middleware():
|
||||
# make args and kwargs
|
||||
args, kwargs = self._make_args_kwargs(microbatch)
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
self.num_microbatches, forward_only)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
@ -284,14 +308,14 @@ class WorkerBase(ABC):
|
|||
self_arg_lst.append(arg_lst[off])
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
self.num_microbatches, forward_only)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# put input tensor which other nodes need into output_list as Phase.INPUT
|
||||
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
self.num_microbatches, forward_only)
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list[recv_input_key] = work_item_remote
|
||||
|
@ -317,7 +341,7 @@ class WorkerBase(ABC):
|
|||
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
|
||||
def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
|
||||
"""
|
||||
You should call this function asynchronously
|
||||
|
@ -336,7 +360,7 @@ class WorkerBase(ABC):
|
|||
producer_stage_ids = self.get_producer_stage_ids()
|
||||
producer_num = len(producer_stage_ids)
|
||||
if self.need_model_input():
|
||||
producer_num += 1 # for input partition
|
||||
producer_num += 1 # for input partition
|
||||
subscribe_forward_futures: List[Future] = [None] * producer_num
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
|
@ -344,26 +368,28 @@ class WorkerBase(ABC):
|
|||
producer_stage_id = 0
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
|
||||
|
||||
for i in range(0, producer_num-1):
|
||||
for i in range(0, producer_num - 1):
|
||||
producer_stage_id = producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
||||
subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key)
|
||||
|
||||
else:
|
||||
for i in range(producer_num):
|
||||
producer_stage_id = producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key)
|
||||
|
||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
|
||||
return work_item_from_producer
|
||||
|
||||
|
||||
# TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
|
||||
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
|
@ -377,20 +403,20 @@ class WorkerBase(ABC):
|
|||
self.work_list[key] = work_item_from_producer
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
def subscribe_consumer(self, microbatch_id: int):
|
||||
def _subscribe_consumer(self, microbatch_id: int):
|
||||
"""
|
||||
You should call this function asynchronously
|
||||
"""
|
||||
assert self.producer_stage_ids is not None
|
||||
consumer_num = len(self.consumer_stage_ids)
|
||||
assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
|
||||
|
||||
stage_id = self.pp_rank
|
||||
subscribe_backward_futures: List[Future] = [None] * consumer_num
|
||||
output = self._get_future_by_device()
|
||||
|
||||
if not self.use_middleware():
|
||||
consumer_stage_ids = self.consumer_stage_ids
|
||||
else:
|
||||
consumer_stage_ids = self.get_consumer_stage_ids()
|
||||
consumer_num = len(consumer_stage_ids)
|
||||
subscribe_backward_futures: List[Future] = [None] * consumer_num
|
||||
for i in range(consumer_num):
|
||||
consumer_stage_id = self.consumer_stage_ids[i]
|
||||
consumer_stage_id = consumer_stage_ids[i]
|
||||
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
||||
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
|
||||
|
@ -399,13 +425,20 @@ class WorkerBase(ABC):
|
|||
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, False)
|
||||
|
||||
# add work_item to work_list
|
||||
return work_item_from_consumer
|
||||
|
||||
def subscribe_consumer(self, microbatch_id: int):
|
||||
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
with self.work_list_condition_lock:
|
||||
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
assert key not in self.work_list
|
||||
self.work_list[key] = work_item_from_consumer
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
if key not in self.work_list:
|
||||
# On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
|
||||
# can only be executed once for every producer-consumer stage pair, which is necessary
|
||||
# to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same
|
||||
# lock of work_item queue operation gurantees the consistency of lifecycle counter.
|
||||
work_item_from_consumer = self._subscribe_consumer(microbatch_id)
|
||||
self.work_list[key] = work_item_from_consumer
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
def get_producer_stage_ids(self):
|
||||
producer_stage_ids = []
|
||||
rank = self.pp_rank
|
||||
|
@ -425,7 +458,7 @@ class WorkerBase(ABC):
|
|||
if partition_id != model_input_partition_id:
|
||||
producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
|
||||
return producer_stage_ids
|
||||
|
||||
|
||||
def get_consumer_stage_ids(self):
|
||||
consumer_stage_ids = []
|
||||
rank = self.pp_rank
|
||||
|
@ -462,7 +495,7 @@ class WorkerBase(ABC):
|
|||
for i, id in enumerate(partition_ids):
|
||||
if id == partition_id:
|
||||
return i
|
||||
|
||||
|
||||
def get_topo(self):
|
||||
with self.partition_condition_lock:
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
|
@ -470,13 +503,13 @@ class WorkerBase(ABC):
|
|||
return self.module_partition._topo
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def use_middleware(self):
|
||||
topo = self.get_topo()
|
||||
return topo is not None
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
def _get_real_args_kwargs(self, args_or_kwargs):
|
||||
def _get_real_args_kwargs_fwd(self, args_or_kwargs):
|
||||
if not self.use_middleware():
|
||||
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||
if args_or_kwargs is not None:
|
||||
|
@ -491,8 +524,8 @@ class WorkerBase(ABC):
|
|||
if args_or_kwargs is not None:
|
||||
if isinstance(args_or_kwargs, dict):
|
||||
pass
|
||||
else:
|
||||
flatten_args = []
|
||||
else:
|
||||
flatten_args = []
|
||||
if self.is_first_stage():
|
||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
# TODO get by offset
|
||||
|
@ -525,7 +558,7 @@ class WorkerBase(ABC):
|
|||
if stage_id == src_stage_id:
|
||||
src_index += i
|
||||
break
|
||||
else: # data from input partition
|
||||
else: # data from input partition
|
||||
src_index = 0
|
||||
# when output_len = 1, not iterable
|
||||
if output_len == 1:
|
||||
|
@ -536,6 +569,55 @@ class WorkerBase(ABC):
|
|||
args_or_kwargs = flatten_args
|
||||
return args_or_kwargs
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
def _get_real_args_kwargs_bwd(self, args_or_kwargs):
|
||||
if not self.use_middleware():
|
||||
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||
if args_or_kwargs is not None:
|
||||
if isinstance(args_or_kwargs, dict):
|
||||
pass
|
||||
else:
|
||||
flatten_args = []
|
||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
args_or_kwargs = flatten_args
|
||||
else:
|
||||
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||
if args_or_kwargs is not None:
|
||||
flatten_args = []
|
||||
# TODO get by offset
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
|
||||
output_vals = self_partition.get_output_vals()
|
||||
consumer_stage_ids = self.get_consumer_stage_ids()
|
||||
for val_list in output_vals:
|
||||
# An output may be passed to many down stages.
|
||||
target = None
|
||||
for val_pos in val_list.get():
|
||||
dst_partition_id = val_pos.partition_id
|
||||
dst_offset = val_pos.offset
|
||||
dst_partition = topo.get_partition_by_id(dst_partition_id)
|
||||
input_len = len(dst_partition.get_input_vals())
|
||||
dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
|
||||
for i, stage_id in enumerate(consumer_stage_ids):
|
||||
if stage_id == dst_stage_id:
|
||||
dst_index = i
|
||||
break
|
||||
if input_len == 1:
|
||||
part_grad = args_or_kwargs[dst_index]
|
||||
else:
|
||||
part_grad = args_or_kwargs[dst_index][dst_offset]
|
||||
|
||||
if target is None:
|
||||
target = part_grad
|
||||
elif part_grad is not None:
|
||||
target += part_grad
|
||||
else:
|
||||
continue
|
||||
flatten_args.append(target)
|
||||
args_or_kwargs = flatten_args
|
||||
return args_or_kwargs
|
||||
|
||||
@abstractmethod
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
"""
|
||||
|
@ -547,7 +629,7 @@ class WorkerBase(ABC):
|
|||
|
||||
def is_last_stage(self):
|
||||
return self.pp_rank == self.actual_stage_num - 1
|
||||
|
||||
|
||||
def need_model_input(self):
|
||||
need_input = False
|
||||
topo: Topo = self.get_topo()
|
||||
|
@ -558,10 +640,13 @@ class WorkerBase(ABC):
|
|||
if model_input_partition_id in partition_inputs:
|
||||
need_input = True
|
||||
return not self.is_first_stage() and need_input
|
||||
|
||||
|
||||
def is_model_output(self):
|
||||
return self.is_last_stage()
|
||||
|
||||
def is_model_input(self):
|
||||
return self.is_first_stage()
|
||||
|
||||
def _default_data_process_func(self, args_kwargs):
|
||||
if self.is_first_stage():
|
||||
args = args_kwargs[0]
|
||||
|
@ -598,11 +683,16 @@ class WorkerBase(ABC):
|
|||
|
||||
# parse and integrate args and kwargs
|
||||
if is_first_stage:
|
||||
args = self._get_real_args_kwargs(args)
|
||||
kwargs = self._get_real_args_kwargs(kwargs)
|
||||
args = self._get_real_args_kwargs_fwd(args)
|
||||
kwargs = self._get_real_args_kwargs_fwd(kwargs)
|
||||
args_kwargs = (args, kwargs)
|
||||
else:
|
||||
args_kwargs = self._get_real_args_kwargs(args)
|
||||
args_kwargs = self._get_real_args_kwargs_fwd(args)
|
||||
|
||||
if not forward_only:
|
||||
pytree_map(args_kwargs,
|
||||
lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
|
||||
process_types=torch.Tensor)
|
||||
|
||||
args, kwargs = data_process_func(args_kwargs)
|
||||
|
||||
|
@ -694,21 +784,40 @@ class WorkerBase(ABC):
|
|||
|
||||
# overlap recompute and future.wait
|
||||
if not is_last_stage:
|
||||
grad_tensors = self._get_real_args_kwargs(args)
|
||||
grad_tensors = self._get_real_args_kwargs_bwd(args)
|
||||
else:
|
||||
grad_tensors = None
|
||||
|
||||
# take tensor only (for only tensor can do backward)
|
||||
stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
|
||||
grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
|
||||
# TODO(jiangziyue) : All values which should do bp are torch.Tensor?
|
||||
stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor)
|
||||
grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor)
|
||||
|
||||
# output all input's grad to producer, even it has no grad(output None)
|
||||
# to make the offset aligned to the topo's record.
|
||||
if grad_tensors is not None:
|
||||
filtered_outputs = []
|
||||
filtered_grads = []
|
||||
for i, grad in enumerate(grad_tensors):
|
||||
stage_output = stage_outputs[i]
|
||||
if stage_output.requires_grad and grad is not None:
|
||||
filtered_outputs.append(stage_output)
|
||||
filtered_grads.append(grad)
|
||||
|
||||
stage_outputs = filtered_outputs
|
||||
grad_tensors = filtered_grads
|
||||
|
||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||
|
||||
# collect grad of input tensor
|
||||
consume_result = []
|
||||
if not is_first_stage:
|
||||
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
|
||||
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
|
||||
# In current design, input mush be a flatten args.
|
||||
for arg in stage_input_args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
consume_result.append(arg.grad)
|
||||
else:
|
||||
consume_result.append(None)
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
||||
|
@ -740,11 +849,11 @@ class WorkerBase(ABC):
|
|||
def _hook_before_step(self):
|
||||
pass
|
||||
|
||||
def _reset_context(self):
|
||||
self.forward_times = 0
|
||||
self.backward_times = 0
|
||||
self.outstanding = 0
|
||||
self._initialize_outstanding_range()
|
||||
# install the main loop to wait for next batch input
|
||||
def _wait_for_reset(self):
|
||||
with self.reset_condition:
|
||||
self.reset_condition.wait_for(lambda: self.reset)
|
||||
self.reset = False
|
||||
|
||||
# do the main loop to consume ready_list
|
||||
def _work_loop(self):
|
||||
|
@ -755,10 +864,9 @@ class WorkerBase(ABC):
|
|||
# main loop
|
||||
while True:
|
||||
work_item_key = self._get_work_item_key()
|
||||
|
||||
# move current work item to output_list to activate subscribe in advance
|
||||
with self.work_list_condition_lock:
|
||||
#self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
|
||||
self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
|
||||
work_item = self.work_list[work_item_key]
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
|
@ -768,16 +876,32 @@ class WorkerBase(ABC):
|
|||
|
||||
consume_result = self._consume_work_item_by_phase(work_item)
|
||||
|
||||
work_item.output.set_result(consume_result)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list.pop(work_item_key)
|
||||
work_item.output.set_result(consume_result)
|
||||
|
||||
# if is last step in one batch reset context and do step
|
||||
if self._is_last_step(work_item):
|
||||
self._hook_before_step()
|
||||
if hasattr(self, 'optimizer') and not work_item.forward_only:
|
||||
self.step()
|
||||
self._reset_context()
|
||||
self._wait_for_reset()
|
||||
|
||||
# reset context and resume loop
|
||||
def reset_context(self):
|
||||
self.forward_times = 0
|
||||
self.backward_times = 0
|
||||
self.outstanding = 0
|
||||
self._initialize_outstanding_range()
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list.clear()
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list.clear()
|
||||
|
||||
with self.reset_condition:
|
||||
self.reset = True
|
||||
self.reset_condition.notify_all()
|
||||
|
||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
|
||||
|
@ -856,7 +980,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
||||
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
|
||||
e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
|
||||
e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
|
||||
pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
|
||||
of partitions will be moved to device 0 and the others to device 1
|
||||
"""
|
||||
|
@ -947,7 +1071,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
|
||||
for pp_rank in input_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
worker_rref.rpc_sync().get_output_by_key(key)
|
||||
worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
|
||||
|
||||
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
|
||||
num_microbatches = self.num_microbatches
|
||||
|
@ -965,6 +1089,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
# TODO : add relationship between output_pp_ranks and parts of microlabels
|
||||
worker_rref.remote().set_labels(microbatch_id, microlabels)
|
||||
|
||||
# TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
|
||||
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
for pp_rank in output_pp_ranks:
|
||||
|
@ -993,6 +1118,16 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
return forward_result
|
||||
|
||||
def _reset_worker(self):
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
for pp_rank in range(actual_stage_num):
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
fut = worker_rref.rpc_async().reset_context()
|
||||
self.step_futs.append(fut)
|
||||
|
||||
for fut in self.step_futs:
|
||||
fut.wait()
|
||||
|
||||
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
||||
batch_lengths = get_batch_lengths(batch)
|
||||
batch_length = batch_lengths[0]
|
||||
|
@ -1046,6 +1181,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
worker_rref.rpc_sync().wait_for_step()
|
||||
|
||||
self._reset_worker() # reset worker attributes for next batch
|
||||
return forward_result
|
||||
|
||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||
|
|
|
@ -89,9 +89,6 @@ class OneFOneBWorker(WorkerBase):
|
|||
elif target_key.microbatch_id == num_microbatches - 1:
|
||||
self.outstanding_range = (0, 0)
|
||||
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
||||
|
||||
return target_key
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,6 @@ def split_batch(batch: Any, start, stop, device: str):
|
|||
def type_detail(obj):
|
||||
return pytree_map(obj, lambda x: type(x), map_all=True)
|
||||
|
||||
|
||||
def pytree_filter(fn, obj, process_types):
|
||||
if obj is None:
|
||||
return None
|
||||
|
|
|
@ -31,7 +31,7 @@ class MLP(nn.Module):
|
|||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
return x.sum()
|
||||
|
||||
class DAG_MLP(nn.Module):
|
||||
def __init__(self, dim: int, layers: int):
|
||||
|
@ -46,7 +46,7 @@ class DAG_MLP(nn.Module):
|
|||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
y = self.dag_layer(y)
|
||||
return x, y
|
||||
return x.sum(), y.sum()
|
||||
|
||||
class RpcTestModel(nn.Module):
|
||||
|
||||
|
|
|
@ -41,10 +41,10 @@ def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int
|
|||
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
||||
return partition
|
||||
|
||||
def run_master(model_cls, world_size):
|
||||
def run_master(model_cls, world_size, forward_only):
|
||||
torch.manual_seed(100)
|
||||
|
||||
epoch = 10
|
||||
epoch = 3
|
||||
device = 'cuda'
|
||||
stage_num = world_size
|
||||
chunk = 1
|
||||
|
@ -57,6 +57,10 @@ def run_master(model_cls, world_size):
|
|||
kwargs = dict(x=x)
|
||||
return kwargs
|
||||
model = model_cls(dim, stage_num * 3)
|
||||
if forward_only:
|
||||
labels = None
|
||||
else:
|
||||
labels = 1
|
||||
elif model_cls == DAG_MLP:
|
||||
def data_gen():
|
||||
x = torch.zeros((batch_size, dim))
|
||||
|
@ -64,24 +68,30 @@ def run_master(model_cls, world_size):
|
|||
kwargs = dict(x=x, y=y)
|
||||
return kwargs
|
||||
model = model_cls(dim, stage_num * 3)
|
||||
if forward_only:
|
||||
labels = None
|
||||
else:
|
||||
labels = 1
|
||||
else:
|
||||
pass
|
||||
|
||||
data_kwargs = data_gen()
|
||||
|
||||
|
||||
engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
checkpoint=use_checkpoint,)
|
||||
if not forward_only:
|
||||
engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
|
||||
|
||||
for _ in range(epoch):
|
||||
input_x = torch.randn((batch_size, dim), device=device)
|
||||
input_y = torch.randn((batch_size, dim), device=device)
|
||||
logits = engine.forward_backward({'x': input_x, 'y': input_y}, forward_only=True)
|
||||
logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
|
||||
|
||||
def run_worker(rank, model_cls, world_size, master_func):
|
||||
def run_worker(rank, model_cls, world_size, forward_only, master_func):
|
||||
master_addr = 'localhost'
|
||||
master_port = 29020
|
||||
os.environ['MASTER_ADDR'] = master_addr
|
||||
|
@ -99,19 +109,20 @@ def run_worker(rank, model_cls, world_size, master_func):
|
|||
|
||||
# in rpc mode, only rank 0 is needed to be coded
|
||||
if rank == 0:
|
||||
master_func(model_cls, world_size)
|
||||
master_func(model_cls, world_size, forward_only)
|
||||
# barrier here
|
||||
if rpc_is_initialized():
|
||||
rpc.shutdown()
|
||||
|
||||
@pytest.mark.skip("skip due to CI torch version 1.11")
|
||||
@parameterize('model_cls', [MLP, DAG_MLP])
|
||||
@parameterize('forward_only', [True, False])
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pp_middleware_fwd(model_cls):
|
||||
def test_pp_middleware_fwd(model_cls, forward_only):
|
||||
world_size = 4
|
||||
master_func = run_master
|
||||
mp.spawn(run_worker, args=(model_cls, world_size, master_func), nprocs=world_size)
|
||||
mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pp_middleware_fwd()
|
||||
test_pp_middleware_fwd()
|
Loading…
Reference in New Issue