[autoparallel] Add pofo sequence annotation (#1637)

* [autoparallel] annotate pofo sequence

* [autoparallel] remove unused print

* [autoparallel] fix some code
pull/1643/head
Boyuan Yao 2 years ago committed by GitHub
parent 04bbabeea8
commit f921733621
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -145,7 +145,7 @@ def _find_ckpt_regions(nodes: List[Node]):
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
tuple in the form of (idx, offload_input, offload_bar). idx indicates the offload
list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
region's index, offload_input is a bool type indicates whether we need to offload
the input, offload_bar is a bool type indicates whether we need to offload all the
intermediate x_bars of this region.
@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list):
act_offload_label = node.activation_offload
if current_region == None:

@ -97,6 +97,7 @@ class PofoSolver:
self.bandwidth = bandwidth
self.disc_chain = copy.deepcopy(self.chain)
self.disc_chain._discretize(self.mem_unit)
self.rotor_table = _compute_table(self.disc_chain, mem_slots)
self._compute_pofo_table()
@ -142,7 +143,7 @@ class PofoSolver:
return (max(compute, comm) + compute + comm) / 2
def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
return _rec(self.disc_chain, i, j, math.floor(m - self.chain.cweight[i] / self.mem_unit), self.rotor_table)
return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)
def _common_values_enable(self, state: Tuple):
@ -354,6 +355,129 @@ class PofoSolver:
return result
def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
# forward annotation
for op in fwd_list:
if in_ckpt:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", [ckpt_idx])
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
setattr(n, "activation_checkpoint", [ckpt_idx])
ckpt_idx += 1
ckpt_region = [op.index]
else:
if isinstance(op, ForwardCheck):
in_ckpt = True
ckpt_region.append(op.index)
# annotate the backward if there is any nested activation checkpoint
in_recompute = False
for op in bwd_list:
if in_recompute:
if isinstance(op, ForwardNograd):
ckpt_region.append(op.index)
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.activation_checkpoint.append(ckpt_idx)
in_recompute = False
else:
if not isinstance(op, Backward):
in_recompute = True
ckpt_idx = 0
ckpt_region = []
if isinstance(op, ForwardCheck):
ckpt_region.append(op.index)
# postprocess, make sure every activation checkpoint label in the
# same activation checkpoint region (level = 0) has the same length
op_list = []
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
for idx in range(start_idx, end_idx + 1):
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
# annotate the offload
offload_idx = 0
for idx, op in enumerate(fwd_list):
if isinstance(op, Offload):
# corner case: offload input
if op.index == 0:
if isinstance(fwd_list[idx + 1], ForwardCheck):
for n in node_list[op.index]:
setattr(n, "activation_offload", True)
else:
for n in node_list[op.index]:
setattr(n, "activation_offload", (offload_idx, True, False))
offload_idx += 1
else:
if op.has_bar:
# annotate previous node
if hasattr(node_list[op.index - 1][0], "activation_offload"):
for n in node_list[op.index - 1]:
n.activation_offload[-1] = True
else:
for n in node_list[op.index - 1]:
setattr(n, "activation_offload", [offload_idx, False, True])
offload_idx += 1
# annotate this node
if isinstance(fwd_list[idx + 1], ForwardCheck):
for n in node_list[op.index]:
setattr(n, "activation_offload", True)
else:
for n in node_list[op.index]:
setattr(n, "activation_offload", [offload_idx, True, False])
offload_idx += 1
def solver_pofo(gm: ColoGraphModule,
data,
bandwidth,
@ -398,7 +522,8 @@ def solver_pofo(gm: ColoGraphModule,
first_state = (0, 0, 0, 0, False)
sequence = solver.pofo_rec(first_state)
if sequence == None:
print(f"Can not solve strategy with {mem_limit / 1024**2} MB memory!")
raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")
_annotate_from_pofo_sequence(sequence, node_list)
setattr(gm, "__sequence__", sequence)
return gm

@ -54,7 +54,8 @@ class Offload(Operation):
super().__init__()
self.index = index
self.name = "Off"
if has_bar:
self.has_bar = has_bar
if self.has_bar:
self.name += "wBar"
def __repr__(self):
@ -67,7 +68,8 @@ class Prefetch(Operation):
super().__init__()
self.index = index
self.name = "Pre"
if has_bar:
self.has_bar = has_bar
if self.has_bar:
self.name += "wBar"
def __repr__(self):

Loading…
Cancel
Save