mirror of https://github.com/hpcaitech/ColossalAI
[Pipeline Middleware] Reduce comm redundancy by getting accurate output (#2232)
* move to cpu to avoid dead lock * get output by offsets Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>pull/2272/head
parent
09c0102fe6
commit
8b045b3c1f
|
@ -185,18 +185,7 @@ class WorkerBase(ABC):
|
|||
self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
|
||||
self.partition_condition_lock.notify_all()
|
||||
|
||||
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
|
||||
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
|
||||
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
|
||||
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
|
||||
|
||||
# for some schedule need the other worker's info to initialise partition (like Chimera)
|
||||
# construction of partition is executed after the registion of pp_rank_to_worker_rref
|
||||
self._initialize_partition()
|
||||
|
||||
# res_use works for lifecycle counter,
|
||||
# if ref_use is True, lifecycle won't add.
|
||||
def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
|
||||
def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None):
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
||||
output_work_item = self.output_list[key]
|
||||
|
@ -214,7 +203,8 @@ class WorkerBase(ABC):
|
|||
lifecycle += 1
|
||||
elif output_work_item.phase == Phase.BACKWARD:
|
||||
lifecycle = len(self.get_producer_stage_ids())
|
||||
if self._is_last_step(output_work_item): # an extra reference for ensure_backward
|
||||
if self.is_model_input() and self._is_last_step(
|
||||
output_work_item): # an extra reference for ensure_backward
|
||||
lifecycle += 1
|
||||
else:
|
||||
lifecycle = 0
|
||||
|
@ -230,6 +220,26 @@ class WorkerBase(ABC):
|
|||
|
||||
return output
|
||||
|
||||
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
|
||||
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
|
||||
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
|
||||
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
|
||||
|
||||
# for some schedule need the other worker's info to initialise partition (like Chimera)
|
||||
# construction of partition is executed after the registion of pp_rank_to_worker_rref
|
||||
self._initialize_partition()
|
||||
|
||||
# res_use works for lifecycle counter,
|
||||
# if ref_use is True, lifecycle won't add.
|
||||
# offset supports get partial output to reduce comm costs.
|
||||
def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
|
||||
output = self._get_output_all(key, ref_use, rank)
|
||||
if offsets is None: # get all for non iterable output
|
||||
return output
|
||||
else: # get part for iterable output
|
||||
output = [output[i] for i in offsets]
|
||||
return output
|
||||
|
||||
def get_parameters(self) -> List[torch.Tensor]:
|
||||
return [p for p in self.module_partition.parameters()]
|
||||
|
||||
|
@ -361,22 +371,35 @@ class WorkerBase(ABC):
|
|||
producer_stage_id = 0
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
|
||||
offsets = self._get_input_offsets_by_index(target_index=0)
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
|
||||
rank=self.pp_rank,
|
||||
offsets=offsets)
|
||||
|
||||
for i in range(0, producer_num - 1):
|
||||
producer_stage_id = producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key)
|
||||
target_index = i + 1
|
||||
offsets = self._get_input_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_forward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank)
|
||||
|
||||
else:
|
||||
for i in range(producer_num):
|
||||
producer_stage_id = producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key)
|
||||
target_index = i
|
||||
offsets = self._get_input_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_forward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
|
||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
|
@ -412,7 +435,13 @@ class WorkerBase(ABC):
|
|||
consumer_stage_id = consumer_stage_ids[i]
|
||||
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
||||
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
|
||||
target_index = i
|
||||
offsets = self._get_output_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_backward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
|
||||
consumer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
|
||||
# flatten args
|
||||
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
|
||||
|
@ -501,6 +530,75 @@ class WorkerBase(ABC):
|
|||
topo = self.get_topo()
|
||||
return topo is not None
|
||||
|
||||
def _get_input_offsets_by_index(self, target_index):
|
||||
res = []
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
|
||||
model_input_partition_id = topo.get_input_partition_id()
|
||||
input_vals = self_partition.get_input_vals()
|
||||
producer_stage_ids = self.get_producer_stage_ids()
|
||||
if self.need_model_input():
|
||||
# 0 for data from input batch
|
||||
# >= 1 for data from prev stages
|
||||
base = 1
|
||||
else:
|
||||
# data from prev stages
|
||||
base = 0
|
||||
for val in input_vals:
|
||||
val_pos = val.get()
|
||||
src_partition_id = val_pos.partition_id
|
||||
src_offset = val_pos.offset
|
||||
src_index = base
|
||||
src_partition = topo.get_partition_by_id(src_partition_id)
|
||||
output_len = len(src_partition.get_output_vals())
|
||||
# data from not-input partition
|
||||
if src_partition_id != model_input_partition_id:
|
||||
src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
|
||||
src_index = base
|
||||
for i, stage_id in enumerate(producer_stage_ids):
|
||||
if stage_id == src_stage_id:
|
||||
src_index += i
|
||||
break
|
||||
else: # data from input partition
|
||||
src_index = 0
|
||||
# when output_len = 1, not iterable
|
||||
if target_index == src_index:
|
||||
if output_len == 1:
|
||||
res = None # offset = None to get all outputs
|
||||
return res
|
||||
else:
|
||||
res.append(src_offset)
|
||||
return res
|
||||
|
||||
def _get_output_offsets_by_index(self, target_index):
|
||||
res = []
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
|
||||
output_vals = self_partition.get_output_vals()
|
||||
consumer_stage_ids = self.get_consumer_stage_ids()
|
||||
for val_list in output_vals:
|
||||
# An output may be passed to many down stages.
|
||||
target = None
|
||||
for val_pos in val_list.get():
|
||||
dst_partition_id = val_pos.partition_id
|
||||
dst_offset = val_pos.offset
|
||||
dst_partition = topo.get_partition_by_id(dst_partition_id)
|
||||
input_len = len(dst_partition.get_input_vals())
|
||||
dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
|
||||
for i, stage_id in enumerate(consumer_stage_ids):
|
||||
if stage_id == dst_stage_id:
|
||||
dst_index = i
|
||||
break
|
||||
if target_index == dst_index:
|
||||
if input_len == 1:
|
||||
res = None # offset = None to get all outputs
|
||||
return res
|
||||
else:
|
||||
res.append(dst_offset)
|
||||
return res
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
def _get_real_args_kwargs_fwd(self, args_or_kwargs):
|
||||
if not self.use_middleware():
|
||||
|
@ -521,8 +619,7 @@ class WorkerBase(ABC):
|
|||
flatten_args = []
|
||||
if self.is_first_stage():
|
||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
# TODO get by offset
|
||||
else:
|
||||
else: # get by offset
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
|
||||
|
@ -557,7 +654,9 @@ class WorkerBase(ABC):
|
|||
if output_len == 1:
|
||||
target = args_or_kwargs[src_index]
|
||||
else:
|
||||
target = args_or_kwargs[src_index][src_offset]
|
||||
offsets = self._get_input_offsets_by_index(src_index)
|
||||
real_offset = offsets.index(src_offset)
|
||||
target = args_or_kwargs[src_index][real_offset]
|
||||
flatten_args.append(target)
|
||||
args_or_kwargs = flatten_args
|
||||
return args_or_kwargs
|
||||
|
@ -574,10 +673,10 @@ class WorkerBase(ABC):
|
|||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
args_or_kwargs = flatten_args
|
||||
else:
|
||||
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||
if args_or_kwargs is not None:
|
||||
for i, arg in enumerate(args_or_kwargs):
|
||||
args_or_kwargs[i] = arg.wait()
|
||||
if args_or_kwargs is not None: # get by offset
|
||||
flatten_args = []
|
||||
# TODO get by offset
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
|
||||
|
@ -599,7 +698,9 @@ class WorkerBase(ABC):
|
|||
if input_len == 1:
|
||||
part_grad = args_or_kwargs[dst_index]
|
||||
else:
|
||||
part_grad = args_or_kwargs[dst_index][dst_offset]
|
||||
offsets = self._get_output_offsets_by_index(dst_index)
|
||||
real_offsets = offsets.index(dst_offset)
|
||||
part_grad = args_or_kwargs[dst_index][real_offsets]
|
||||
|
||||
if target is None:
|
||||
target = part_grad
|
||||
|
@ -682,10 +783,6 @@ class WorkerBase(ABC):
|
|||
else:
|
||||
args_kwargs = self._get_real_args_kwargs_fwd(args)
|
||||
|
||||
# if not forward_only:
|
||||
# pytree_map(args_kwargs,
|
||||
# lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
|
||||
# process_types=torch.Tensor)
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
|
||||
|
@ -752,14 +849,14 @@ class WorkerBase(ABC):
|
|||
stage_input_kwargs,
|
||||
stage_outputs,
|
||||
checkpoint=use_checkpoint)
|
||||
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in
|
||||
|
||||
# if not forward_only, do the backward
|
||||
if not forward_only:
|
||||
if is_last_stage: # if it is the last stage, trigger backward automatic
|
||||
self._begin_backward(microbatch_id)
|
||||
|
||||
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
|
||||
elif phase == Phase.BACKWARD:
|
||||
# remind its producer to get data before backward
|
||||
if not is_first_stage:
|
||||
|
@ -803,10 +900,8 @@ class WorkerBase(ABC):
|
|||
filtered_grads.append(grad)
|
||||
|
||||
stage_outputs = filtered_outputs
|
||||
grad_tensors = filtered_grads
|
||||
|
||||
grad_tensors = pyobj_map(grad_tensors, fn=lambda x: x.to(self.device),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||
|
||||
# collect grad of input tensor
|
||||
|
@ -941,8 +1036,6 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
|
||||
|
||||
self.step_futs: List[Future] = []
|
||||
|
||||
self._check_argument()
|
||||
self._create_pp_rank_to_rpc_worker_id()
|
||||
self._create_pp_rank_to_module_partition_id()
|
||||
|
@ -1058,9 +1151,14 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
|
||||
else:
|
||||
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
|
||||
futs = []
|
||||
for pp_rank in input_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
|
||||
fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[])
|
||||
futs.append(fut)
|
||||
|
||||
for fut in futs:
|
||||
fut.wait()
|
||||
|
||||
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
|
||||
num_microbatches = self.num_microbatches
|
||||
|
@ -1087,10 +1185,16 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
|
||||
if not forward_only:
|
||||
backward_result = []
|
||||
for pp_rank in input_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
|
||||
worker_rref.rpc_sync().get_output_by_key(key)
|
||||
fut = worker_rref.rpc_async().get_output_by_key(
|
||||
key, offsets=[]) # only ensure the res exists, no need for real data.
|
||||
backward_result.append(fut)
|
||||
|
||||
for fut in backward_result:
|
||||
fut.wait()
|
||||
|
||||
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
|
||||
forward_result = []
|
||||
|
@ -1109,12 +1213,13 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
def _reset_worker(self):
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
reset_futs: List[Future] = []
|
||||
for pp_rank in range(actual_stage_num):
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
fut = worker_rref.rpc_async().reset_context()
|
||||
self.step_futs.append(fut)
|
||||
reset_futs.append(fut)
|
||||
|
||||
for fut in self.step_futs:
|
||||
for fut in reset_futs:
|
||||
fut.wait()
|
||||
|
||||
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
||||
|
@ -1141,7 +1246,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
for microbatch_id in range(num_microbatches):
|
||||
# control data input speed
|
||||
# to prevent exceed of wait limitations
|
||||
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
|
||||
# self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
|
||||
batch_start = microbatch_size * microbatch_id
|
||||
batch_end = min(batch_start + microbatch_size, batch_length)
|
||||
|
||||
|
@ -1178,10 +1283,11 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
def step(self):
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
step_futs: List[Future] = []
|
||||
for pp_rank in range(actual_stage_num):
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
fut = worker_rref.rpc_async().step()
|
||||
self.step_futs.append(fut)
|
||||
step_futs.append(fut)
|
||||
|
||||
for fut in self.step_futs:
|
||||
for fut in step_futs:
|
||||
fut.wait()
|
||||
|
|
Loading…
Reference in New Issue