diff --git a/chunk_codegen.py b/chunk_codegen.py index a14f7c134..9930a0570 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -3,20 +3,11 @@ import torch import copy from typing import List, Callable, Any, Tuple, Dict, Iterable -try: - from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name - from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin - from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size - CODEGEN_AVAILABLE = True -except: - from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin - from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name - CODEGEN_AVAILABLE = False - -if CODEGEN_AVAILABLE: - __all__ = ['ChunkCodeGen'] -else: - __all__ = ['python_code_with_activation_checkpoint'] +from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name +from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin +from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size +CODEGEN_AVAILABLE = True +__all__ = ['ChunkCodeGen'] class NodeIndexTracer(object): @@ -289,9 +280,9 @@ class NodeIndexTracer(object): 2. compute the real value of -1 in target shape. 3. determine changed dim, and assgin index for generated dim. 4. log changed dim and generated dim for restore - 5. look into view list to see whether the view is associated with other, + 5. inherit computation. + 6. TODO: look into view list to see whether the view is associated with other, if so assgin equal dim according to previous view. - 6. inherit computation. Args: node (node) @@ -352,7 +343,7 @@ class NodeIndexTracer(object): self.mark_computation(node, node_idx, [j]) break - # log view + # log view, not used now view_dict = {"idx_from": [origin_trace[i] for i in dim_from], "dim_from": dim_from, "idx_to": [new_trace[i] for i in dim_to], @@ -680,239 +671,6 @@ def _find_idx_by_name(name, nodes_list): if node.name == name: return idx raise RuntimeError("name %s not found in node list" % name) - - -def _find_offload_regions(nodes: List[Node]): - """This function is to find the offload regions - In pofo algorithm, during annotation, we will annotate the offload region with the - list in the form of [idx, offload_input, offload_bar]. idx indicates the offload - region's index, offload_input is a bool type indicates whether we need to offload - the input, offload_bar is a bool type indicates whether we need to offload all the - intermediate x_bars of this region. - """ - offload_regions = [] - offload_labels = [] - start = -1 - end = -1 - current_region = None - - for idx, node in enumerate(nodes): - if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable): - act_offload_label = node.activation_offload - - if current_region == None: - current_region = act_offload_label - start = idx - offload_labels.append(act_offload_label) - - if act_offload_label != current_region: - assert start != -1 - offload_regions.append((start, idx - 1)) - offload_labels.append(act_offload_label) - current_region = act_offload_label - start = idx - end = -1 - - else: - if current_region is not None: - end = idx - 1 - assert start != -1 and end != -1 - offload_regions.append((start, end)) - start = end = -1 - current_region = None - - else: - pass - - return offload_regions, offload_labels - - -def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: - """ - Generate the checkpoint function definition - """ - return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):" - - -def _gen_ckpt_output(output_vars: List[str]) -> str: - """ - Generate the return statement for checkpoint region - """ - return f"return {', '.join(output_vars)}" - - -def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True): - """ - Generate the checkpoint function call code text - """ - outputs = ', '.join(output_vars) - inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' - - -def _end_of_ckpt(node: Node, check_idx: int) -> bool: - """Check if the node could end the ckpt region - - Args: - node (Node): torch.fx.Node - check_idx (int): the index of checkpoint level for - nested checkpoint - - Returns: - bool - """ - if hasattr(node, "activation_checkpoint"): - if isinstance(node.activation_checkpoint, list): - return node.activation_checkpoint[check_idx] == None - else: - return False - else: - return True - - -def _find_nested_ckpt_regions(nodes, check_idx=0): - """ - Find the nested checkpoint regions given a list of consecutive nodes. The outputs - will be list of tuples, each tuple is in the form of (start_index, end_index). - """ - ckpt_regions = [] - start = -1 - end = -1 - current_region = None - - for idx, node in enumerate(nodes): - if hasattr(node, 'activation_checkpoint'): - if isinstance(getattr(node, 'activation_checkpoint'), int): - act_ckpt_label = node.activation_checkpoint - else: - act_ckpt_label = node.activation_checkpoint[check_idx] - - # this activation checkpoint label is not set yet - # meaning this is the first node of the activation ckpt region - if current_region is None: - current_region = act_ckpt_label - start = idx - - # if activation checkpoint has changed - # we restart the tracking - # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] - if act_ckpt_label != current_region: - assert start != -1 - ckpt_regions.append((start, idx - 1)) - current_region = act_ckpt_label - start = idx - end = -1 - elif current_region is not None and _end_of_ckpt(node, check_idx): - # used to check the case below - # node ckpt states = [ckpt, ckpt, non-ckpt] - end = idx - 1 - assert start != -1 and end != -1 - ckpt_regions.append((start, end)) - start = end = -1 - current_region = None - else: - pass - - if current_region is not None: - end = len(nodes) - 1 - ckpt_regions.append((start, end)) - return ckpt_regions - - -def emit_ckpt_func(body, - ckpt_func, - node_list: List[Node], - emit_node_func, - delete_unused_value_func, - level=0, - in_ckpt=False): - """Emit ckpt fuction in nested way - - Args: - body: forward code, in recursive calls, this part will be checkpoint - functions code - ckpt_func: checkpoint functions code, in recursive calls, this part - will be a buffer - node_list (List[Node]): list of torch.fx.Node - emit_node_func: function to emit a node - delete_unused_value_func: function to delete unused value - level (int, optional): checkpoint level. Defaults to 0. - in_ckpt (bool, optional): indicates wether the func is in recursive - call. Defaults to False. - """ - inputs, outputs = _find_input_and_output_nodes(node_list) - - # if the current checkpoint function use int as label, using old generation method - if isinstance(node_list[0].activation_checkpoint, int): - label = node_list[0].activation_checkpoint - ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') - for node in node_list: - emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] - delete_unused_value_func(node, ckpt_func) - - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = getattr(node_list[0], "activation_offload", False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) - usage += "\n" - body.append(usage) - - # use nested ckpt function codegen - else: - # label given by each layer, e.g. if you are currently at level [0, 1, 1] - # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]]) - ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') - - # if there is more level to fetch - if level + 1 < len(node_list[0].activation_checkpoint): - ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) - start_idx = [item[0] for item in ckpt_regions] - end_idx = [item[1] for item in ckpt_regions] - - # use ckpt_func_buffer to store nested checkpoint functions - ckpt_func_buffer = [] - node_idx = 0 - while 1: - if node_idx >= len(node_list): - break - - if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, - delete_unused_value_func, level + 1, True) - node_idx += len(ckpt_node_list) - - else: - node = node_list[node_idx] - emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] - delete_unused_value_func(node, ckpt_func) - node_idx += 1 - - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - ckpt_func += ckpt_func_buffer - activation_offload = getattr(node_list[0], "activation_offload", False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' - if in_ckpt: - usage = ' ' + usage - body.append(usage) - - # last level - else: - for node in node_list: - emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] - delete_unused_value_func(node, ckpt_func) - - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = getattr(node_list[0], "activation_offload", False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' - if in_ckpt: - usage = ' ' + usage - body.append(usage) def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):