mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Add pofo sequence annotation (#1637)
* [autoparallel] annotate pofo sequence * [autoparallel] remove unused print * [autoparallel] fix some codepull/1643/head
parent
04bbabeea8
commit
f921733621
|
@ -145,7 +145,7 @@ def _find_ckpt_regions(nodes: List[Node]):
|
|||
def _find_offload_regions(nodes: List[Node]):
|
||||
"""This function is to find the offload regions
|
||||
In pofo algorithm, during annotation, we will annotate the offload region with the
|
||||
tuple in the form of (idx, offload_input, offload_bar). idx indicates the offload
|
||||
list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
|
||||
region's index, offload_input is a bool type indicates whether we need to offload
|
||||
the input, offload_bar is a bool type indicates whether we need to offload all the
|
||||
intermediate x_bars of this region.
|
||||
|
@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
|
|||
current_region = None
|
||||
|
||||
for idx, node in enumerate(nodes):
|
||||
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple):
|
||||
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list):
|
||||
act_offload_label = node.activation_offload
|
||||
|
||||
if current_region == None:
|
||||
|
|
|
@ -97,6 +97,7 @@ class PofoSolver:
|
|||
self.bandwidth = bandwidth
|
||||
|
||||
self.disc_chain = copy.deepcopy(self.chain)
|
||||
self.disc_chain._discretize(self.mem_unit)
|
||||
|
||||
self.rotor_table = _compute_table(self.disc_chain, mem_slots)
|
||||
self._compute_pofo_table()
|
||||
|
@ -142,7 +143,7 @@ class PofoSolver:
|
|||
return (max(compute, comm) + compute + comm) / 2
|
||||
|
||||
def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
|
||||
return _rec(self.disc_chain, i, j, math.floor(m - self.chain.cweight[i] / self.mem_unit), self.rotor_table)
|
||||
return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)
|
||||
|
||||
def _common_values_enable(self, state: Tuple):
|
||||
|
||||
|
@ -354,6 +355,129 @@ class PofoSolver:
|
|||
return result
|
||||
|
||||
|
||||
def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
||||
op_list = sequence.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
fwd_list = op_list[:op_list.index(loss_op)]
|
||||
bwd_list = op_list[op_list.index(loss_op) + 1:]
|
||||
ckpt_idx = 0
|
||||
in_ckpt = False
|
||||
ckpt_region = []
|
||||
|
||||
# forward annotation
|
||||
for op in fwd_list:
|
||||
if in_ckpt:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
in_ckpt = False
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
setattr(n, "activation_checkpoint", [ckpt_idx])
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
|
||||
else:
|
||||
if isinstance(op, ForwardCheck):
|
||||
in_ckpt = True
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
# annotate the backward if there is any nested activation checkpoint
|
||||
in_recompute = False
|
||||
for op in bwd_list:
|
||||
if in_recompute:
|
||||
if isinstance(op, ForwardNograd):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
elif isinstance(op, ForwardEnable):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = []
|
||||
|
||||
elif isinstance(op, ForwardCheck):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
ckpt_idx += 1
|
||||
ckpt_region = [op.index]
|
||||
|
||||
elif isinstance(op, Backward):
|
||||
for node_idx in ckpt_region:
|
||||
for n in node_list[node_idx]:
|
||||
n.activation_checkpoint.append(ckpt_idx)
|
||||
|
||||
in_recompute = False
|
||||
|
||||
else:
|
||||
if not isinstance(op, Backward):
|
||||
in_recompute = True
|
||||
ckpt_idx = 0
|
||||
ckpt_region = []
|
||||
if isinstance(op, ForwardCheck):
|
||||
ckpt_region.append(op.index)
|
||||
|
||||
# postprocess, make sure every activation checkpoint label in the
|
||||
# same activation checkpoint region (level = 0) has the same length
|
||||
op_list = []
|
||||
for node in node_list:
|
||||
op_list += node
|
||||
ckpt_regions = _find_nested_ckpt_regions(op_list)
|
||||
for (start_idx, end_idx) in ckpt_regions:
|
||||
nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
|
||||
|
||||
# annotate the offload
|
||||
offload_idx = 0
|
||||
for idx, op in enumerate(fwd_list):
|
||||
if isinstance(op, Offload):
|
||||
# corner case: offload input
|
||||
if op.index == 0:
|
||||
if isinstance(fwd_list[idx + 1], ForwardCheck):
|
||||
for n in node_list[op.index]:
|
||||
setattr(n, "activation_offload", True)
|
||||
else:
|
||||
for n in node_list[op.index]:
|
||||
setattr(n, "activation_offload", (offload_idx, True, False))
|
||||
offload_idx += 1
|
||||
|
||||
else:
|
||||
if op.has_bar:
|
||||
# annotate previous node
|
||||
if hasattr(node_list[op.index - 1][0], "activation_offload"):
|
||||
for n in node_list[op.index - 1]:
|
||||
n.activation_offload[-1] = True
|
||||
else:
|
||||
for n in node_list[op.index - 1]:
|
||||
setattr(n, "activation_offload", [offload_idx, False, True])
|
||||
|
||||
offload_idx += 1
|
||||
|
||||
# annotate this node
|
||||
if isinstance(fwd_list[idx + 1], ForwardCheck):
|
||||
for n in node_list[op.index]:
|
||||
setattr(n, "activation_offload", True)
|
||||
else:
|
||||
for n in node_list[op.index]:
|
||||
setattr(n, "activation_offload", [offload_idx, True, False])
|
||||
|
||||
offload_idx += 1
|
||||
|
||||
|
||||
def solver_pofo(gm: ColoGraphModule,
|
||||
data,
|
||||
bandwidth,
|
||||
|
@ -398,7 +522,8 @@ def solver_pofo(gm: ColoGraphModule,
|
|||
first_state = (0, 0, 0, 0, False)
|
||||
sequence = solver.pofo_rec(first_state)
|
||||
if sequence == None:
|
||||
print(f"Can not solve strategy with {mem_limit / 1024**2} MB memory!")
|
||||
raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")
|
||||
|
||||
_annotate_from_pofo_sequence(sequence, node_list)
|
||||
setattr(gm, "__sequence__", sequence)
|
||||
return gm
|
||||
|
|
|
@ -54,7 +54,8 @@ class Offload(Operation):
|
|||
super().__init__()
|
||||
self.index = index
|
||||
self.name = "Off"
|
||||
if has_bar:
|
||||
self.has_bar = has_bar
|
||||
if self.has_bar:
|
||||
self.name += "wBar"
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -67,7 +68,8 @@ class Prefetch(Operation):
|
|||
super().__init__()
|
||||
self.index = index
|
||||
self.name = "Pre"
|
||||
if has_bar:
|
||||
self.has_bar = has_bar
|
||||
if self.has_bar:
|
||||
self.name += "wBar"
|
||||
|
||||
def __repr__(self):
|
||||
|
|
Loading…
Reference in New Issue