mirror of https://github.com/hpcaitech/ColossalAI
polish code
parent f379d1a94d
commit 7e2bd1e428
258  chunk_codegen.py
@@ -3,20 +3,11 @@ import torch
 import copy
 from typing import List, Callable, Any, Tuple, Dict, Iterable

-try:
-    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
-    from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
-    CODEGEN_AVAILABLE = True
-except:
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
-    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
-    CODEGEN_AVAILABLE = False
-
-if CODEGEN_AVAILABLE:
-    __all__ = ['ChunkCodeGen']
-else:
-    __all__ = ['python_code_with_activation_checkpoint']
+from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
+CODEGEN_AVAILABLE = True
+__all__ = ['ChunkCodeGen']


 class NodeIndexTracer(object):
@@ -289,9 +280,9 @@ class NodeIndexTracer(object):
         2. compute the real value of -1 in target shape.
         3. determine changed dim, and assign index for generated dim.
         4. log changed dim and generated dim for restore.
-        5. look into view list to see whether the view is associated with others,
+        5. inherit computation.
+        6. TODO: look into view list to see whether the view is associated with others,
            if so assign equal dim according to previous view.
-        6. inherit computation.

         Args:
             node (node)
@@ -352,7 +343,7 @@ class NodeIndexTracer(object):
                     self.mark_computation(node, node_idx, [j])
                     break

-        # log view
+        # log view, not used now
         view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
                      "dim_from": dim_from,
                      "idx_to": [new_trace[i] for i in dim_to],
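
For orientation, a minimal sketch of the kind of view the docstring steps above describe (shapes and names are illustrative, not taken from the commit):

import torch

x = torch.randn(4, 16, 64)    # dims 0, 1, 2
y = x.view(4, 16, 8, -1)      # step 2: -1 resolves to 8

# Step 3: dim 2 of x is the changed dim; dims 2 and 3 of y are the generated dims
# and need fresh indices. Steps 4-5 would log roughly
#     {"dim_from": [2], "dim_to": [2, 3]}
# so that computation marked on dim 2 of x can be inherited by y.
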
@@ -680,239 +671,6 @@ def _find_idx_by_name(name, nodes_list):
         if node.name == name:
             return idx
     raise RuntimeError("name %s not found in node list" % name)
-
-
-def _find_offload_regions(nodes: List[Node]):
-    """This function finds the offload regions.
-    In the pofo algorithm, during annotation, we annotate an offload region with a
-    list of the form [idx, offload_input, offload_bar]. idx indicates the offload
-    region's index, offload_input is a bool indicating whether we need to offload
-    the input, and offload_bar is a bool indicating whether we need to offload all the
-    intermediate x_bars of this region.
-    """
-    offload_regions = []
-    offload_labels = []
-    start = -1
-    end = -1
-    current_region = None
-
-    for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
-            act_offload_label = node.activation_offload
-
-            if current_region is None:
-                current_region = act_offload_label
-                start = idx
-                offload_labels.append(act_offload_label)
-
-            if act_offload_label != current_region:
-                assert start != -1
-                offload_regions.append((start, idx - 1))
-                offload_labels.append(act_offload_label)
-                current_region = act_offload_label
-                start = idx
-                end = -1
-
-        else:
-            if current_region is not None:
-                end = idx - 1
-                assert start != -1 and end != -1
-                offload_regions.append((start, end))
-                start = end = -1
-                current_region = None
-
-            else:
-                pass
-
-    return offload_regions, offload_labels
-
-
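
For illustration, a small sketch of what _find_offload_regions would return for a hypothetical annotated node list (the SimpleNamespace objects are stand-ins, not real fx nodes):

from types import SimpleNamespace

# Two consecutive nodes share the pofo annotation [0, True, False]:
# offload region 0, offload the input, do not offload intermediate x_bars.
nodes = [
    SimpleNamespace(name='a'),
    SimpleNamespace(name='b', activation_offload=[0, True, False]),
    SimpleNamespace(name='c', activation_offload=[0, True, False]),
    SimpleNamespace(name='d'),
]

# _find_offload_regions(nodes) would return
#     offload_regions == [(1, 2)]             # start/end node indices of the region
#     offload_labels  == [[0, True, False]]   # one annotation per region
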
-def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
-    """
-    Generate the checkpoint function definition
-    """
-    return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
-
-
-def _gen_ckpt_output(output_vars: List[str]) -> str:
-    """
-    Generate the return statement for checkpoint region
-    """
-    return f"return {', '.join(output_vars)}"
-
-
-def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True):
-    """
-    Generate the checkpoint function call code text
-    """
-    outputs = ', '.join(output_vars)
-    inputs = ', '.join(input_vars)
-    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
-
-
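
As a quick reference, the strings the three helpers above produce for sample arguments (values are illustrative only):

# _gen_ckpt_fn_def(0, ['x', 'y'])
#     -> "def checkpoint_0(self, x, y):"
# _gen_ckpt_output(['out'])
#     -> "return out"
# _gen_ckpt_usage(0, False, ['x', 'y'], ['out'], use_reentrant=False)
#     -> "out = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, y, use_reentrant=False)"

Note that emit_ckpt_func below always passes False for use_reentrant when it builds this call.
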
-def _end_of_ckpt(node: Node, check_idx: int) -> bool:
-    """Check if the node could end the ckpt region
-
-    Args:
-        node (Node): torch.fx.Node
-        check_idx (int): the index of checkpoint level for
-            nested checkpoint
-
-    Returns:
-        bool
-    """
-    if hasattr(node, "activation_checkpoint"):
-        if isinstance(node.activation_checkpoint, list):
-            return node.activation_checkpoint[check_idx] is None
-        else:
-            return False
-    else:
-        return True
-
-
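
A few concrete return values of _end_of_ckpt for hypothetical nodes (SimpleNamespace stand-ins, not real fx nodes):

from types import SimpleNamespace

# no annotation at all -> the node can end a ckpt region
_end_of_ckpt(SimpleNamespace(), check_idx=0)                                  # True
# list annotation whose entry at check_idx is None -> region ends at this level
_end_of_ckpt(SimpleNamespace(activation_checkpoint=[0, None]), check_idx=1)   # True
# int annotation -> handled by the non-nested path, never ends a nested region
_end_of_ckpt(SimpleNamespace(activation_checkpoint=0), check_idx=0)           # False
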
-def _find_nested_ckpt_regions(nodes, check_idx=0):
-    """
-    Find the nested checkpoint regions given a list of consecutive nodes. The output
-    will be a list of tuples, each tuple in the form of (start_index, end_index).
-    """
-    ckpt_regions = []
-    start = -1
-    end = -1
-    current_region = None
-
-    for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            if isinstance(getattr(node, 'activation_checkpoint'), int):
-                act_ckpt_label = node.activation_checkpoint
-            else:
-                act_ckpt_label = node.activation_checkpoint[check_idx]
-
-            # this activation checkpoint label is not set yet
-            # meaning this is the first node of the activation ckpt region
-            if current_region is None:
-                current_region = act_ckpt_label
-                start = idx
-
-            # if activation checkpoint has changed
-            # we restart the tracking
-            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
-            if act_ckpt_label != current_region:
-                assert start != -1
-                ckpt_regions.append((start, idx - 1))
-                current_region = act_ckpt_label
-                start = idx
-                end = -1
-        elif current_region is not None and _end_of_ckpt(node, check_idx):
-            # used to check the case below
-            # node ckpt states = [ckpt, ckpt, non-ckpt]
-            end = idx - 1
-            assert start != -1 and end != -1
-            ckpt_regions.append((start, end))
-            start = end = -1
-            current_region = None
-        else:
-            pass
-
-    if current_region is not None:
-        end = len(nodes) - 1
-        ckpt_regions.append((start, end))
-    return ckpt_regions
-
-
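
A hedged sketch of how _find_nested_ckpt_regions groups a hypothetical node list (SimpleNamespace stand-ins; each activation_checkpoint list holds one label per nesting level):

from types import SimpleNamespace

nodes = [
    SimpleNamespace(activation_checkpoint=[0, 0]),
    SimpleNamespace(activation_checkpoint=[0, 0]),
    SimpleNamespace(activation_checkpoint=[0, 1]),
    SimpleNamespace(activation_checkpoint=[1, None]),
    SimpleNamespace(),
]

# _find_nested_ckpt_regions(nodes, check_idx=0)      -> [(0, 2), (3, 3)]   # grouped by the outer label
# _find_nested_ckpt_regions(nodes[:3], check_idx=1)  -> [(0, 1), (2, 2)]   # grouped by the inner label
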
-def emit_ckpt_func(body,
-                   ckpt_func,
-                   node_list: List[Node],
-                   emit_node_func,
-                   delete_unused_value_func,
-                   level=0,
-                   in_ckpt=False):
-    """Emit ckpt function in a nested way
-
-    Args:
-        body: forward code; in recursive calls, this part will be the checkpoint
-            functions' code
-        ckpt_func: checkpoint functions code; in recursive calls, this part
-            will be a buffer
-        node_list (List[Node]): list of torch.fx.Node
-        emit_node_func: function to emit a node
-        delete_unused_value_func: function to delete unused values
-        level (int, optional): checkpoint level. Defaults to 0.
-        in_ckpt (bool, optional): indicates whether the function is in a recursive
-            call. Defaults to False.
-    """
-    inputs, outputs = _find_input_and_output_nodes(node_list)
-
-    # if the current checkpoint function uses an int as its label, use the old generation method
-    if isinstance(node_list[0].activation_checkpoint, int):
-        label = node_list[0].activation_checkpoint
-        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
-        ckpt_func.append(f'{ckpt_fn_def}\n')
-        for node in node_list:
-            emit_node_func(node, ckpt_func)
-            ckpt_func[-1] = '    ' + ckpt_func[-1]
-            delete_unused_value_func(node, ckpt_func)
-
-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-        activation_offload = getattr(node_list[0], "activation_offload", False)
-        usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
-        usage += "\n"
-        body.append(usage)
-
-    # use nested ckpt function codegen
-    else:
-        # label given by each layer, e.g. if you are currently at level [0, 1, 1]
-        # the label will be '0_1_1'
-        label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
-        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
-        ckpt_func.append(f'{ckpt_fn_def}\n')
-
-        # if there are more levels to fetch
-        if level + 1 < len(node_list[0].activation_checkpoint):
-            ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
-            start_idx = [item[0] for item in ckpt_regions]
-            end_idx = [item[1] for item in ckpt_regions]
-
-            # use ckpt_func_buffer to store nested checkpoint functions
-            ckpt_func_buffer = []
-            node_idx = 0
-            while 1:
-                if node_idx >= len(node_list):
-                    break
-
-                if node_idx in start_idx:
-                    ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
-                    emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
-                                   delete_unused_value_func, level + 1, True)
-                    node_idx += len(ckpt_node_list)
-
-                else:
-                    node = node_list[node_idx]
-                    emit_node_func(node, ckpt_func)
-                    ckpt_func[-1] = '    ' + ckpt_func[-1]
-                    delete_unused_value_func(node, ckpt_func)
-                    node_idx += 1
-
-            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-            ckpt_func += ckpt_func_buffer
-            activation_offload = getattr(node_list[0], "activation_offload", False)
-            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
-            if in_ckpt:
-                usage = '    ' + usage
-            body.append(usage)
-
-        # last level
-        else:
-            for node in node_list:
-                emit_node_func(node, ckpt_func)
-                ckpt_func[-1] = '    ' + ckpt_func[-1]
-                delete_unused_value_func(node, ckpt_func)
-
-            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-            activation_offload = getattr(node_list[0], "activation_offload", False)
-            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
-            if in_ckpt:
-                usage = '    ' + usage
-            body.append(usage)
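
To make the emitted text concrete, a hypothetical result of emit_ckpt_func for one region whose nodes are all annotated with activation_checkpoint=[0, 0], with a single input x and output out (the inner node code and names are invented; activation_offload is assumed absent, hence False):

# ckpt_func would end up holding the outer wrapper first, then the nested function:
#
#     def checkpoint_0(self, x):
#         out = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)
#         return out
#
#     def checkpoint_0_0(self, x):
#         out = torch.relu(x)    # whatever emit_node_func emits for the region's nodes
#         return out
#
# and body (the main forward code) would contain the single call:
#
#     out = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)
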


 def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):