mirror of https://github.com/hpcaitech/ColossalAI
[Pipeline Middleware] Reduce comm redundancy by getting accurate output (#2232)
* move to CPU to avoid deadlock
* get output by offsets

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Branch: pull/2272/head
parent: 09c0102fe6
commit: 8b045b3c1f
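
The core of the change (second bullet above) is that a consumer now asks a producer only for the output positions it actually consumes instead of pulling the whole output over RPC. Below is a minimal standalone sketch of that idea; OutputStore and its methods are purely illustrative and not ColossalAI APIs.

    # Hypothetical sketch of "get output by offsets"; OutputStore is illustrative only.
    from typing import Any, List, Optional


    class OutputStore:

        def __init__(self) -> None:
            self._outputs = {}    # key -> list of produced values

        def put(self, key: str, values: List[Any]) -> None:
            self._outputs[key] = values

        def get_output_by_key(self, key: str, offsets: Optional[List[int]] = None) -> Any:
            output = self._outputs[key]
            if offsets is None:    # non-iterable output, or the caller wants everything
                return output
            # only ship the elements this consumer actually needs
            return [output[i] for i in offsets]


    if __name__ == "__main__":
        store = OutputStore()
        store.put("mb0-forward", ["hidden", "attn_mask", "router_logits"])
        # a consumer that only uses elements 0 and 2 transfers two items, not three
        print(store.get_output_by_key("mb0-forward", offsets=[0, 2]))

With offsets=[0, 2] only two of the three values cross the wire, which is the communication saving the commit title refers to.
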
@@ -185,18 +185,7 @@ class WorkerBase(ABC):
             self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
             self.partition_condition_lock.notify_all()

-    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
-        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
-        assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
-        self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
-
-        # for some schedule need the other worker's info to initialise partition (like Chimera)
-        # construction of partition is executed after the registion of pp_rank_to_worker_rref
-        self._initialize_partition()
-
-    # res_use works for lifecycle counter,
-    # if ref_use is True, lifecycle won't add.
-    def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
+    def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None):
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
@@ -214,7 +203,8 @@ class WorkerBase(ABC):
                     lifecycle += 1
                 elif output_work_item.phase == Phase.BACKWARD:
                     lifecycle = len(self.get_producer_stage_ids())
-                    if self._is_last_step(output_work_item):    # an extra reference for ensure_backward
+                    if self.is_model_input() and self._is_last_step(
+                            output_work_item):    # an extra reference for ensure_backward
                         lifecycle += 1
                 else:
                     lifecycle = 0
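
The lifecycle value maintained in this hunk behaves like a reference count on a stored output: it is set to the number of expected reads, a read with ref_use=True does not consume a reference, and the entry can be dropped once the count reaches zero. A tiny sketch of that bookkeeping follows; the OutputEntry class and the numbers are illustrative assumptions, not the ColossalAI implementation.

    # Reference-count style sketch of the "lifecycle" bookkeeping (illustrative only).
    class OutputEntry:

        def __init__(self, value, lifecycle: int) -> None:
            self.value = value
            self.lifecycle = lifecycle    # how many reads are expected before the entry can be freed

        def read(self, ref_use: bool = False):
            if not ref_use:               # ref_use=True peeks without consuming a reference
                self.lifecycle -= 1
            return self.value

        def expired(self) -> bool:
            return self.lifecycle <= 0


    entry = OutputEntry(value="grad", lifecycle=2)    # e.g. two consumers will fetch this backward output
    entry.read()
    entry.read(ref_use=True)                          # ensure_backward-style peek, no decrement
    entry.read()
    assert entry.expired()
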
@@ -230,6 +220,26 @@ class WorkerBase(ABC):

         return output

+    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
+        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
+        assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
+        self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
+
+        # for some schedule need the other worker's info to initialise partition (like Chimera)
+        # construction of partition is executed after the registion of pp_rank_to_worker_rref
+        self._initialize_partition()
+
+    # res_use works for lifecycle counter,
+    # if ref_use is True, lifecycle won't add.
+    # offset supports get partial output to reduce comm costs.
+    def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
+        output = self._get_output_all(key, ref_use, rank)
+        if offsets is None:    # get all for non iterable output
+            return output
+        else:    # get part for iterable output
+            output = [output[i] for i in offsets]
+            return output
+
     def get_parameters(self) -> List[torch.Tensor]:
         return [p for p in self.module_partition.parameters()]

@@ -361,22 +371,35 @@ class WorkerBase(ABC):
                 producer_stage_id = 0
                 producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
                 producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
+                offsets = self._get_input_offsets_by_index(target_index=0)
+                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
+                                                                                                  rank=self.pp_rank,
+                                                                                                  offsets=offsets)

                 for i in range(0, producer_num - 1):
                     producer_stage_id = producer_stage_ids[i]
                     producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                     producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                    subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
-                        producer_output_key)
+                    target_index = i + 1
+                    offsets = self._get_input_offsets_by_index(target_index=target_index)
+                    if offsets is not None and len(offsets) == 0:    # no need to do rpc
+                        subscribe_forward_futures[target_index] = []
+                    else:
+                        subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
+                            producer_output_key, rank=self.pp_rank)

             else:
                 for i in range(producer_num):
                     producer_stage_id = producer_stage_ids[i]
                     producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                     producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                    subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
-                        producer_output_key)
+                    target_index = i
+                    offsets = self._get_input_offsets_by_index(target_index=target_index)
+                    if offsets is not None and len(offsets) == 0:    # no need to do rpc
+                        subscribe_forward_futures[target_index] = []
+                    else:
+                        subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
+                            producer_output_key, rank=self.pp_rank, offsets=offsets)

             work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                                microbatch_id, None, self.num_microbatches, forward_only)
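
The loops above all follow the same pattern: compute the offsets this stage needs from each producer, skip the RPC entirely when that list is empty, and otherwise request only those offsets. A condensed, runnable sketch of that pattern follows; AsyncProducer is a hypothetical stand-in for a producer worker rref, not part of ColossalAI.

    # Hypothetical condensed form of the producer-subscription loop above.
    from concurrent.futures import Future
    from typing import List, Optional


    class AsyncProducer:
        """Stand-in for a producer worker rref; returns already-completed futures."""

        def __init__(self, output: List[str]) -> None:
            self._output = output

        def get_output_by_key_async(self, offsets: Optional[List[int]]) -> Future:
            fut: Future = Future()
            fut.set_result(self._output if offsets is None else [self._output[i] for i in offsets])
            return fut


    def subscribe(producers: List[AsyncProducer], offsets_per_producer: List[Optional[List[int]]]):
        futures: List[object] = [None] * len(producers)
        for target_index, (producer, offsets) in enumerate(zip(producers, offsets_per_producer)):
            if offsets is not None and len(offsets) == 0:
                futures[target_index] = []    # nothing needed from this producer: skip the RPC
            else:
                futures[target_index] = producer.get_output_by_key_async(offsets)
        return futures


    producers = [AsyncProducer(["a", "b"]), AsyncProducer(["c", "d", "e"])]
    futs = subscribe(producers, offsets_per_producer=[[1], []])
    print(futs[0].result(), futs[1])    # ['b'] []
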
@@ -412,7 +435,13 @@ class WorkerBase(ABC):
                 consumer_stage_id = consumer_stage_ids[i]
                 consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
                 consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
-                subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
+                target_index = i
+                offsets = self._get_output_offsets_by_index(target_index=target_index)
+                if offsets is not None and len(offsets) == 0:    # no need to do rpc
+                    subscribe_backward_futures[target_index] = []
+                else:
+                    subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
+                        consumer_output_key, rank=self.pp_rank, offsets=offsets)

             # flatten args
             work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
@@ -501,6 +530,75 @@ class WorkerBase(ABC):
         topo = self.get_topo()
         return topo is not None

+    def _get_input_offsets_by_index(self, target_index):
+        res = []
+        topo: Topo = self.get_topo()
+        self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+        self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+        model_input_partition_id = topo.get_input_partition_id()
+        input_vals = self_partition.get_input_vals()
+        producer_stage_ids = self.get_producer_stage_ids()
+        if self.need_model_input():
+            # 0 for data from input batch
+            # >= 1 for data from prev stages
+            base = 1
+        else:
+            # data from prev stages
+            base = 0
+        for val in input_vals:
+            val_pos = val.get()
+            src_partition_id = val_pos.partition_id
+            src_offset = val_pos.offset
+            src_index = base
+            src_partition = topo.get_partition_by_id(src_partition_id)
+            output_len = len(src_partition.get_output_vals())
+            # data from not-input partition
+            if src_partition_id != model_input_partition_id:
+                src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
+                src_index = base
+                for i, stage_id in enumerate(producer_stage_ids):
+                    if stage_id == src_stage_id:
+                        src_index += i
+                        break
+            else:    # data from input partition
+                src_index = 0
+            # when output_len = 1, not iterable
+            if target_index == src_index:
+                if output_len == 1:
+                    res = None    # offset = None to get all outputs
+                    return res
+                else:
+                    res.append(src_offset)
+        return res
+
+    def _get_output_offsets_by_index(self, target_index):
+        res = []
+        topo: Topo = self.get_topo()
+        self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+        self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+        output_vals = self_partition.get_output_vals()
+        consumer_stage_ids = self.get_consumer_stage_ids()
+        for val_list in output_vals:
+            # An output may be passed to many down stages.
+            target = None
+            for val_pos in val_list.get():
+                dst_partition_id = val_pos.partition_id
+                dst_offset = val_pos.offset
+                dst_partition = topo.get_partition_by_id(dst_partition_id)
+                input_len = len(dst_partition.get_input_vals())
+                dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
+                for i, stage_id in enumerate(consumer_stage_ids):
+                    if stage_id == dst_stage_id:
+                        dst_index = i
+                        break
+                if target_index == dst_index:
+                    if input_len == 1:
+                        res = None    # offset = None to get all outputs
+                        return res
+                    else:
+                        res.append(dst_offset)
+        return res
+
     # TODO(jiangziyue) get single value instead of the whole output
     def _get_real_args_kwargs_fwd(self, args_or_kwargs):
         if not self.use_middleware():
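
_get_input_offsets_by_index and _get_output_offsets_by_index walk the partition topology (Topo/Partition) to work out which positions of a neighbour's output this stage consumes. The snippet below is a deliberately simplified version of the same calculation, using plain tuples instead of the real Topo objects; the helper and its data layout are illustrative assumptions only.

    # Simplified offset computation: which positions of each producer's output does
    # this stage consume? The plain-tuple "topology" is a stand-in for Topo/Partition.
    from typing import Dict, List, Optional, Tuple

    # input_vals of the current partition: (producer_index, offset_in_producer_output)
    InputVal = Tuple[int, int]


    def input_offsets_by_index(input_vals: List[InputVal],
                               producer_output_lens: Dict[int, int],
                               target_index: int) -> Optional[List[int]]:
        res: List[int] = []
        for producer_index, offset in input_vals:
            if producer_index != target_index:
                continue
            if producer_output_lens[producer_index] == 1:
                return None    # single, non-iterable output: fetch it whole
            res.append(offset)
        return res


    # this stage consumes offsets 0 and 2 of producer 0, and nothing from producer 1
    input_vals = [(0, 0), (0, 2)]
    print(input_offsets_by_index(input_vals, {0: 3, 1: 2}, target_index=0))    # [0, 2]
    print(input_offsets_by_index(input_vals, {0: 3, 1: 2}, target_index=1))    # []  -> no RPC needed
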
@@ -521,8 +619,7 @@ class WorkerBase(ABC):
             flatten_args = []
             if self.is_first_stage():
                 pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
-            # TODO get by offset
-            else:
+            else:    # get by offset
                 topo: Topo = self.get_topo()
                 self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
                 self_partition: Partition = topo.get_partition_by_id(self_partition_id)
@@ -557,7 +654,9 @@ class WorkerBase(ABC):
                     if output_len == 1:
                         target = args_or_kwargs[src_index]
                     else:
-                        target = args_or_kwargs[src_index][src_offset]
+                        offsets = self._get_input_offsets_by_index(src_index)
+                        real_offset = offsets.index(src_offset)
+                        target = args_or_kwargs[src_index][real_offset]
                     flatten_args.append(target)
                 args_or_kwargs = flatten_args
         return args_or_kwargs
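
Because only a subset of the producer's output is now shipped, the topology's src_offset no longer indexes the received list directly; the hunk above recovers the local position with offsets.index(src_offset). A small worked example of that remapping (values are made up):

    # Remapping a producer-side offset into a position within the partially fetched list.
    full_output = ["h0", "h1", "h2", "h3"]     # what the producer holds
    offsets = [1, 3]                           # what this stage asked for
    received = [full_output[i] for i in offsets]

    src_offset = 3                             # topo says we need the producer's element 3
    real_offset = offsets.index(src_offset)    # -> 1, its position in the received list
    assert received[real_offset] == full_output[src_offset]
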
@@ -574,10 +673,10 @@ class WorkerBase(ABC):
             pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
             args_or_kwargs = flatten_args
         else:
-            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
-            if args_or_kwargs is not None:
+            for i, arg in enumerate(args_or_kwargs):
+                args_or_kwargs[i] = arg.wait()
+            if args_or_kwargs is not None:    # get by offset
                 flatten_args = []
-                # TODO get by offset
                 topo: Topo = self.get_topo()
                 self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
                 self_partition: Partition = topo.get_partition_by_id(self_partition_id)
@@ -599,7 +698,9 @@ class WorkerBase(ABC):
                         if input_len == 1:
                             part_grad = args_or_kwargs[dst_index]
                         else:
-                            part_grad = args_or_kwargs[dst_index][dst_offset]
+                            offsets = self._get_output_offsets_by_index(dst_index)
+                            real_offsets = offsets.index(dst_offset)
+                            part_grad = args_or_kwargs[dst_index][real_offsets]

                         if target is None:
                             target = part_grad
|
@ -682,10 +783,6 @@ class WorkerBase(ABC):
|
||||||
else:
|
else:
|
||||||
args_kwargs = self._get_real_args_kwargs_fwd(args)
|
args_kwargs = self._get_real_args_kwargs_fwd(args)
|
||||||
|
|
||||||
# if not forward_only:
|
|
||||||
# pytree_map(args_kwargs,
|
|
||||||
# lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
|
|
||||||
# process_types=torch.Tensor)
|
|
||||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
|
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
|
||||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||||
|
|
||||||
|
@@ -752,14 +849,14 @@ class WorkerBase(ABC):
                                                  stage_input_kwargs,
                                                  stage_outputs,
                                                  checkpoint=use_checkpoint)
+            consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
+                                       process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
+
             # if not forward_only, do the backward
             if not forward_only:
                 if is_last_stage:    # if it is the last stage, trigger backward automatic
                     self._begin_backward(microbatch_id)

-                consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
-                                           process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
-
         elif phase == Phase.BACKWARD:
             # remind its producer to get data before backward
             if not is_first_stage:
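
Moving the forward result to CPU now happens right after the forward pass, before the result can be fetched through torch RPC, which cannot carry CUDA tensors and, per the commit message, could otherwise deadlock. A rough equivalent of that mapping step is sketched below; to_cpu_tree is an illustrative helper, not ColossalAI's pyobj_map.

    # Illustrative recursive "move every tensor to CPU" pass applied to a forward
    # result before it crosses an RPC boundary.
    import torch


    def to_cpu_tree(obj):
        if isinstance(obj, torch.Tensor):
            return obj.to('cpu')
        if isinstance(obj, (list, tuple)):
            return type(obj)(to_cpu_tree(x) for x in obj)
        if isinstance(obj, dict):
            return {k: to_cpu_tree(v) for k, v in obj.items()}
        return obj


    consume_result = (torch.randn(2, 2), {"loss": torch.tensor(1.0)})
    consume_result = to_cpu_tree(consume_result)    # safe to return via torch.distributed.rpc
    print(consume_result[0].device, consume_result[1]["loss"].device)
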
@@ -803,9 +900,7 @@ class WorkerBase(ABC):
                     filtered_grads.append(grad)

             stage_outputs = filtered_outputs
-            grad_tensors = filtered_grads
-
-            grad_tensors = pyobj_map(grad_tensors, fn=lambda x: x.to(self.device),
+            grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
                                      process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
             autograd.backward(stage_outputs, grad_tensors=grad_tensors)

@@ -941,8 +1036,6 @@ class PipelineEngineBase(ABC, nn.Module):

         self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()

-        self.step_futs: List[Future] = []
-
         self._check_argument()
         self._create_pp_rank_to_rpc_worker_id()
         self._create_pp_rank_to_module_partition_id()
@@ -1058,9 +1151,14 @@ class PipelineEngineBase(ABC, nn.Module):
                 ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
         else:
             key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
+            futs = []
             for pp_rank in input_pp_ranks:
                 worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-                worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
+                fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[])
+                futs.append(fut)
+
+            for fut in futs:
+                fut.wait()

     def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
         num_microbatches = self.num_microbatches
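
The constraint check above also switches from one blocking rpc_sync per rank to firing every rpc_async call first and waiting afterwards, with offsets=[] so that no payload is transferred. The same fan-out-then-join pattern with ordinary Python futures (a sketch, not the RPC code):

    # Fan-out then join: issue every request before waiting on any of them.
    from concurrent.futures import ThreadPoolExecutor
    import time


    def ensure_output_exists(rank: int) -> bool:
        time.sleep(0.1)    # stands in for get_output_by_key(key, ref_use=True, offsets=[])
        return True


    with ThreadPoolExecutor() as pool:
        futs = [pool.submit(ensure_output_exists, rank) for rank in range(4)]    # all requests in flight
        results = [fut.result() for fut in futs]    # join afterwards: ~0.1 s total instead of ~0.4 s

    print(results)
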
@@ -1087,10 +1185,16 @@ class PipelineEngineBase(ABC, nn.Module):

     def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
         if not forward_only:
+            backward_result = []
             for pp_rank in input_pp_ranks:
                 worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                 key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
-                worker_rref.rpc_sync().get_output_by_key(key)
+                fut = worker_rref.rpc_async().get_output_by_key(
+                    key, offsets=[])    # only ensure the res exists, no need for real data.
+                backward_result.append(fut)
+
+            for fut in backward_result:
+                fut.wait()

     def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
         forward_result = []
@@ -1109,12 +1213,13 @@ class PipelineEngineBase(ABC, nn.Module):

     def _reset_worker(self):
         actual_stage_num = self._get_actual_stage_num()
+        reset_futs: List[Future] = []
         for pp_rank in range(actual_stage_num):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
             fut = worker_rref.rpc_async().reset_context()
-            self.step_futs.append(fut)
+            reset_futs.append(fut)

-        for fut in self.step_futs:
+        for fut in reset_futs:
             fut.wait()

     def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
@@ -1141,7 +1246,7 @@ class PipelineEngineBase(ABC, nn.Module):
         for microbatch_id in range(num_microbatches):
            # control data input speed
            # to prevent exceed of wait limitations
-            self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
+            # self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
            batch_start = microbatch_size * microbatch_id
            batch_end = min(batch_start + microbatch_size, batch_length)

@@ -1178,10 +1283,11 @@ class PipelineEngineBase(ABC, nn.Module):

     def step(self):
         actual_stage_num = self._get_actual_stage_num()
+        step_futs: List[Future] = []
         for pp_rank in range(actual_stage_num):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
             fut = worker_rref.rpc_async().step()
-            self.step_futs.append(fut)
+            step_futs.append(fut)

-        for fut in self.step_futs:
+        for fut in step_futs:
             fut.wait()