[Pipeline Middleware] Reduce comm redundancy by getting accurate output (#2232)

* move to CPU to avoid deadlock

* get output by offsets

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Ziyue Jiang 2023-01-03 13:43:57 +08:00 committed by GitHub
parent 09c0102fe6
commit 8b045b3c1f
1 changed file with 152 additions and 46 deletions
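The two commit-message bullets above boil down to the pattern sketched below. This is a minimal, self-contained illustration of the idea, not the ColossalAI code in the diff; the helper names to_cpu_detached and fetch_partial_output are made up for the example (the real worker uses pyobj_map and get_output_by_key with an offsets argument).

from typing import Any, List, Optional

import torch


def to_cpu_detached(obj: Any) -> Any:
    # torch RPC cannot ship CUDA tensors directly, so detach and copy to CPU first.
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu()
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_cpu_detached(x) for x in obj)
    return obj


def fetch_partial_output(output: Any, offsets: Optional[List[int]] = None) -> Any:
    # offsets is None -> the output is not iterable (or all of it is needed): send it whole.
    # offsets == []   -> the caller needs nothing: an empty reply, no tensor transfer.
    # offsets == [..] -> send only those positions, skipping redundant tensors.
    if offsets is None:
        return to_cpu_detached(output)
    return [to_cpu_detached(output[i]) for i in offsets]


if __name__ == "__main__":
    stage_output = [torch.randn(2, 2), torch.randn(3), torch.randn(4)]
    partial = fetch_partial_output(stage_output, offsets=[0, 2])    # consumer only needs 0 and 2
    assert len(partial) == 2 and not partial[0].requires_grad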


@@ -185,18 +185,7 @@ class WorkerBase(ABC):
self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
self.partition_condition_lock.notify_all()
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has already synced global worker rrefs"
assert pp_rank_to_worker_rref is not None, "pp_rank_to_worker_rref must be a dict instead of None"
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
# some schedules (like Chimera) need the other workers' info to initialise the partition,
# so the partition is constructed after pp_rank_to_worker_rref has been registered
self._initialize_partition()
# ref_use works as a lifecycle-counter flag:
# if ref_use is True, the lifecycle count is not incremented.
def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None):
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
@@ -214,7 +203,8 @@ class WorkerBase(ABC):
lifecycle += 1
elif output_work_item.phase == Phase.BACKWARD:
lifecycle = len(self.get_producer_stage_ids())
if self._is_last_step(output_work_item): # an extra reference for ensure_backward
if self.is_model_input() and self._is_last_step(
output_work_item): # an extra reference for ensure_backward
lifecycle += 1
else:
lifecycle = 0
@@ -230,6 +220,26 @@ class WorkerBase(ABC):
return output
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has already synced global worker rrefs"
assert pp_rank_to_worker_rref is not None, "pp_rank_to_worker_rref must be a dict instead of None"
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
# some schedules (like Chimera) need the other workers' info to initialise the partition,
# so the partition is constructed after pp_rank_to_worker_rref has been registered
self._initialize_partition()
# ref_use works as a lifecycle-counter flag:
# if ref_use is True, the lifecycle count is not incremented.
# offsets allows fetching only part of the output to reduce communication cost.
def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
output = self._get_output_all(key, ref_use, rank)
if offsets is None: # get all for a non-iterable output
return output
else: # get part of an iterable output
output = [output[i] for i in offsets]
return output
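# e.g. a consumer stage fetches only the elements it needs, roughly (hypothetical offsets):
#   fut = producer_worker_rref.rpc_async().get_output_by_key(key, rank=self.pp_rank, offsets=[0, 2])
#   partial_output = fut.wait()    # a list containing just the requested positions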
def get_parameters(self) -> List[torch.Tensor]:
return [p for p in self.module_partition.parameters()]
@@ -361,22 +371,35 @@ class WorkerBase(ABC):
producer_stage_id = 0
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
offsets = self._get_input_offsets_by_index(target_index=0)
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
rank=self.pp_rank,
offsets=offsets)
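# passing offsets restricts the RPC reply to the slices this stage actually
# consumes, which is the communication redundancy this commit removes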
for i in range(0, producer_num - 1):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key)
target_index = i + 1
offsets = self._get_input_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank)
else:
for i in range(producer_num):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key)
target_index = i
offsets = self._get_input_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank, offsets=offsets)
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
@@ -412,7 +435,13 @@ class WorkerBase(ABC):
consumer_stage_id = consumer_stage_ids[i]
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
target_index = i
offsets = self._get_output_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_backward_futures[target_index] = []
else:
subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
consumer_output_key, rank=self.pp_rank, offsets=offsets)
# flatten args
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
@@ -501,6 +530,75 @@ class WorkerBase(ABC):
topo = self.get_topo()
return topo is not None
def _get_input_offsets_by_index(self, target_index):
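# Returns which offsets of the producer at position `target_index` this stage
# actually consumes: None means "fetch the whole output" (the producer's output
# is not iterable), an empty list means nothing is needed from that producer,
# and a non-empty list selects only the required elements.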
res = []
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
model_input_partition_id = topo.get_input_partition_id()
input_vals = self_partition.get_input_vals()
producer_stage_ids = self.get_producer_stage_ids()
if self.need_model_input():
# 0 for data from input batch
# >= 1 for data from prev stages
base = 1
else:
# data from prev stages
base = 0
for val in input_vals:
val_pos = val.get()
src_partition_id = val_pos.partition_id
src_offset = val_pos.offset
src_index = base
src_partition = topo.get_partition_by_id(src_partition_id)
output_len = len(src_partition.get_output_vals())
# data from a non-input partition
if src_partition_id != model_input_partition_id:
src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
src_index = base
for i, stage_id in enumerate(producer_stage_ids):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
src_index = 0
# when output_len = 1, not iterable
if target_index == src_index:
if output_len == 1:
res = None # offsets = None to get all outputs
return res
else:
res.append(src_offset)
return res
def _get_output_offsets_by_index(self, target_index):
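# Consumer-side counterpart of _get_input_offsets_by_index: which offsets of this
# stage's own output the consumer at position `target_index` needs. Again, None
# means the whole output and an empty list means the consumer takes nothing.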
res = []
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
output_vals = self_partition.get_output_vals()
consumer_stage_ids = self.get_consumer_stage_ids()
for val_list in output_vals:
# An output may be passed to multiple downstream stages.
target = None
for val_pos in val_list.get():
dst_partition_id = val_pos.partition_id
dst_offset = val_pos.offset
dst_partition = topo.get_partition_by_id(dst_partition_id)
input_len = len(dst_partition.get_input_vals())
dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
for i, stage_id in enumerate(consumer_stage_ids):
if stage_id == dst_stage_id:
dst_index = i
break
if target_index == dst_index:
if input_len == 1:
res = None # offsets = None to get all outputs
return res
else:
res.append(dst_offset)
return res
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs_fwd(self, args_or_kwargs):
if not self.use_middleware():
@@ -521,8 +619,7 @@ class WorkerBase(ABC):
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
# TODO get by offset
else:
else: # get by offset
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
@@ -557,7 +654,9 @@ class WorkerBase(ABC):
if output_len == 1:
target = args_or_kwargs[src_index]
else:
target = args_or_kwargs[src_index][src_offset]
offsets = self._get_input_offsets_by_index(src_index)
real_offset = offsets.index(src_offset)
target = args_or_kwargs[src_index][real_offset]
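# only a subset of the producer's output was fetched, so the topology offset has
# to be remapped to its position inside that fetched subset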
flatten_args.append(target)
args_or_kwargs = flatten_args
return args_or_kwargs
@@ -574,10 +673,10 @@ class WorkerBase(ABC):
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
for i, arg in enumerate(args_or_kwargs):
args_or_kwargs[i] = arg.wait()
if args_or_kwargs is not None: # get by offset
flatten_args = []
# TODO get by offset
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
@@ -599,7 +698,9 @@ class WorkerBase(ABC):
if input_len == 1:
part_grad = args_or_kwargs[dst_index]
else:
part_grad = args_or_kwargs[dst_index][dst_offset]
offsets = self._get_output_offsets_by_index(dst_index)
real_offsets = offsets.index(dst_offset)
part_grad = args_or_kwargs[dst_index][real_offsets]
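# same remapping as in the forward path: map the topology offset into the subset
# of gradients that was actually fetched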
if target is None:
target = part_grad
@@ -682,10 +783,6 @@ class WorkerBase(ABC):
else:
args_kwargs = self._get_real_args_kwargs_fwd(args)
# if not forward_only:
# pytree_map(args_kwargs,
# lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
# process_types=torch.Tensor)
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
@@ -752,14 +849,14 @@ class WorkerBase(ABC):
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
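# moving the result to CPU here, before backward is triggered below, is the
# "move to cpu to avoid dead lock" part of this commit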
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatically
self._begin_backward(microbatch_id)
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
elif phase == Phase.BACKWARD:
# remind its producer to get data before backward
if not is_first_stage:
@@ -803,10 +900,8 @@ class WorkerBase(ABC):
filtered_grads.append(grad)
stage_outputs = filtered_outputs
grad_tensors = filtered_grads
grad_tensors = pyobj_map(grad_tensors, fn=lambda x: x.to(self.device),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
@@ -941,8 +1036,6 @@ class PipelineEngineBase(ABC, nn.Module):
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
self.step_futs: List[Future] = []
self._check_argument()
self._create_pp_rank_to_rpc_worker_id()
self._create_pp_rank_to_module_partition_id()
@@ -1058,9 +1151,14 @@ class PipelineEngineBase(ABC, nn.Module):
ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
else:
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
futs = []
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[])
futs.append(fut)
for fut in futs:
fut.wait()
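# offsets=[] means no output data is actually transferred; the async call only
# blocks until the corresponding backward result exists on the worker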
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
num_microbatches = self.num_microbatches
@@ -1087,10 +1185,16 @@ class PipelineEngineBase(ABC, nn.Module):
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
if not forward_only:
backward_result = []
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
worker_rref.rpc_sync().get_output_by_key(key)
fut = worker_rref.rpc_async().get_output_by_key(
key, offsets=[]) # only ensure the result exists; no real data is needed.
backward_result.append(fut)
for fut in backward_result:
fut.wait()
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
forward_result = []
@@ -1109,12 +1213,13 @@ class PipelineEngineBase(ABC, nn.Module):
def _reset_worker(self):
actual_stage_num = self._get_actual_stage_num()
reset_futs: List[Future] = []
for pp_rank in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
fut = worker_rref.rpc_async().reset_context()
self.step_futs.append(fut)
reset_futs.append(fut)
for fut in self.step_futs:
for fut in reset_futs:
fut.wait()
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
@@ -1141,7 +1246,7 @@ class PipelineEngineBase(ABC, nn.Module):
for microbatch_id in range(num_microbatches):
# control data input speed
# to prevent exceed of wait limitations
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
# self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
batch_start = microbatch_size * microbatch_id
batch_end = min(batch_start + microbatch_size, batch_length)
@@ -1178,10 +1283,11 @@ class PipelineEngineBase(ABC, nn.Module):
def step(self):
actual_stage_num = self._get_actual_stage_num()
step_futs: List[Future] = []
for pp_rank in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
fut = worker_rref.rpc_async().step()
self.step_futs.append(fut)
step_futs.append(fut)
for fut in self.step_futs:
for fut in step_futs:
fut.wait()