|
|
|
@ -97,6 +97,7 @@ class PofoSolver:
|
|
|
|
|
self.bandwidth = bandwidth
|
|
|
|
|
|
|
|
|
|
self.disc_chain = copy.deepcopy(self.chain)
|
|
|
|
|
self.disc_chain._discretize(self.mem_unit)
|
|
|
|
|
|
|
|
|
|
self.rotor_table = _compute_table(self.disc_chain, mem_slots)
|
|
|
|
|
self._compute_pofo_table()
|
|
|
|
@ -142,7 +143,7 @@ class PofoSolver:
|
|
|
|
|
return (max(compute, comm) + compute + comm) / 2
|
|
|
|
|
|
|
|
|
|
def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
    """Look up the rotor result for stages ``i``..``j`` under budget ``m``.

    The weight ``self.chain.cweight[i]`` is first removed from ``m``; what
    is left is expressed in multiples of ``self.mem_unit`` (rounded down)
    and handed to ``_rec`` together with the discretized chain and the
    precomputed ``self.rotor_table``.

    Args:
        i: first index forwarded to ``_rec``.
        j: second index forwarded to ``_rec``.
        m: memory budget, in the same unit as ``self.chain.cweight``.
        delta: not read by this method.

    Returns:
        Whatever ``_rec`` returns for the given arguments.
    """
    # Subtract before dividing: ``m`` and ``cweight[i]`` share a unit, so the
    # difference as a whole must be converted. Writing
    # ``m - cweight[i] / mem_unit`` would divide only the weight and mix units.
    remaining_units = math.floor((m - self.chain.cweight[i]) / self.mem_unit)
    return _rec(self.disc_chain, i, j, remaining_units, self.rotor_table)
|
|
|
|
|
|
|
|
|
|
def _common_values_enable(self, state: Tuple):
|
|
|
|
|
|
|
|
|
@ -354,6 +355,129 @@ class PofoSolver:
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
    """Annotate nodes with checkpoint and offload labels from a solved sequence.

    The operation list returned by ``sequence.list_operations()`` is split at
    its first ``Loss`` op into a forward part and a backward part. Four passes
    then mutate the nodes of ``node_list`` in place:

    1. Forward pass: a ``ForwardCheck`` opens a region, following
       ``ForwardNograd`` ops extend it, and a ``ForwardEnable`` (or the next
       ``ForwardCheck``, which immediately opens a new region) closes it.
       Every node of a closed region gets ``activation_checkpoint = [k]``
       where ``k`` counts closed regions.
    2. Backward pass: the first non-``Backward`` op after ``Loss`` (or after a
       ``Backward``) starts a recomputation. Regions are collected as in the
       forward pass, but their index is *appended* to the existing
       ``activation_checkpoint`` lists and restarts at 0 for every
       recomputation. A ``Backward`` op labels the pending region and ends
       the recomputation.
    3. Padding: within every ``(start, end)`` pair returned by
       ``_find_nested_ckpt_regions`` the labels are right-padded with ``None``
       so that they all have the same length.
    4. Offload: for every ``Offload`` op of the forward part the targeted
       nodes get an ``activation_offload`` attribute, either ``True`` (when
       the next forward op is a ``ForwardCheck``) or a three-element list
       ``[offload_idx, flag, flag]``.

    Args:
        sequence: solved schedule; only ``list_operations()`` is used here.
        node_list: groups of nodes; ``op.index`` selects the group an
            operation refers to.

    Raises:
        StopIteration: if the operation list contains no ``Loss`` op.
    """
    op_list = sequence.list_operations()
    loss_op = next(op for op in op_list if isinstance(op, Loss))
    loss_pos = op_list.index(loss_op)
    fwd_list = op_list[:loss_pos]
    bwd_list = op_list[loss_pos + 1:]

    ckpt_idx = 0
    in_ckpt = False
    ckpt_region = []

    # forward annotation
    for op in fwd_list:
        if in_ckpt:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(op.index)

            elif isinstance(op, ForwardEnable):
                # the region ends here; the ForwardEnable op itself is not
                # part of it
                in_ckpt = False
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint = [ckpt_idx]

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                # back-to-back regions: close the current one and start the
                # next one with this op
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint = [ckpt_idx]

                ckpt_idx += 1
                ckpt_region = [op.index]

        elif isinstance(op, ForwardCheck):
            in_ckpt = True
            ckpt_region.append(op.index)

    # annotate the backward if there is any nested activation checkpoint
    in_recompute = False
    for op in bwd_list:
        if in_recompute:
            if isinstance(op, ForwardNograd):
                ckpt_region.append(op.index)

            elif isinstance(op, ForwardEnable):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = []

            elif isinstance(op, ForwardCheck):
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                ckpt_idx += 1
                ckpt_region = [op.index]

            elif isinstance(op, Backward):
                # label whatever is still pending, then leave recompute mode;
                # ckpt_region is reset when the next recomputation starts
                for node_idx in ckpt_region:
                    for n in node_list[node_idx]:
                        n.activation_checkpoint.append(ckpt_idx)

                in_recompute = False

        elif not isinstance(op, Backward):
            # a forward-type op in the backward part starts a recomputation;
            # nested indices are counted from 0 again
            in_recompute = True
            ckpt_idx = 0
            ckpt_region = []
            if isinstance(op, ForwardCheck):
                ckpt_region.append(op.index)

    # postprocess, make sure every activation checkpoint label in the
    # same activation checkpoint region (level = 0) has the same length
    flat_nodes = [n for group in node_list for n in group]
    for (start_idx, end_idx) in _find_nested_ckpt_regions(flat_nodes):
        nested_length = max(len(flat_nodes[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
        for idx in range(start_idx, end_idx + 1):
            label = flat_nodes[idx].activation_checkpoint
            flat_nodes[idx].activation_checkpoint += [None] * (nested_length - len(label))

    # annotate the offload
    offload_idx = 0
    for idx, op in enumerate(fwd_list):
        if not isinstance(op, Offload):
            continue

        # corner case: offload input (index 0) has no previous node, so the
        # "previous node" step is skipped for it
        if op.index != 0 and op.has_bar:
            # annotate previous node
            if hasattr(node_list[op.index - 1][0], "activation_offload"):
                # NOTE(review): this item assignment needs the existing label
                # to be a list. It would fail if the previous group carries
                # the bare ``True`` label -- confirm the solver never emits
                # that combination.
                for n in node_list[op.index - 1]:
                    n.activation_offload[-1] = True
            else:
                for n in node_list[op.index - 1]:
                    n.activation_offload = [offload_idx, False, True]

                offload_idx += 1

        # annotate this node
        if isinstance(fwd_list[idx + 1], ForwardCheck):
            for n in node_list[op.index]:
                n.activation_offload = True
        else:
            # Always a list (never a tuple): a later Offload op with
            # ``has_bar`` set rewrites the last entry of this label in place.
            # presumably the 2nd/3rd entries are per-node offload flags read
            # by the code generator -- TODO confirm their exact meaning
            for n in node_list[op.index]:
                n.activation_offload = [offload_idx, True, False]

            offload_idx += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def solver_pofo(gm: ColoGraphModule,
|
|
|
|
|
data,
|
|
|
|
|
bandwidth,
|
|
|
|
@ -398,7 +522,8 @@ def solver_pofo(gm: ColoGraphModule,
|
|
|
|
|
first_state = (0, 0, 0, 0, False)
|
|
|
|
|
sequence = solver.pofo_rec(first_state)
|
|
|
|
|
if sequence == None:
|
|
|
|
|
print(f"Can not solve strategy with {mem_limit / 1024**2} MB memory!")
|
|
|
|
|
raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")
|
|
|
|
|
|
|
|
|
|
_annotate_from_pofo_sequence(sequence, node_list)
|
|
|
|
|
setattr(gm, "__sequence__", sequence)
|
|
|
|
|
return gm
|
|
|
|
|