diff --git a/chunk_codegen.py b/chunk_codegen.py index 684028c01..09fda2b98 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -12,7 +12,7 @@ except: CODEGEN_AVAILABLE = False if CODEGEN_AVAILABLE: - __all__ = ['ActivationCheckpointCodeGen'] + __all__ = ['ChunkCodeGen'] else: __all__ = ['python_code_with_activation_checkpoint'] @@ -375,7 +375,7 @@ def emit_ckpt_func(body, body.append(usage) -def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): +def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use this function to emit the activation checkpoint codes. @@ -392,21 +392,21 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod end_idx = [item[1] for item in ckpt_regions] # find the offload regions - offload_regions, offload_labels = _find_offload_regions(nodes) - offload_starts = [item[0] for item in offload_regions] - offload_ends = [item[1] for item in offload_regions] - offload_inputs = [] - offload_outputs = [] - within_offload_region = False + chunk_regions, chunk_labels = _find_offload_regions(nodes) + chunk_starts = [item[0] for item in chunk_regions] + chunk_ends = [item[1] for item in chunk_regions] + chunk_inputs = [] + chunk_outputs = [] + within_chunk_region = False node_list = list(nodes) # find the input and output var names for each offload region - for idx, (start, end) in enumerate(offload_regions): + for idx, (start, end) in enumerate(chunk_regions): offload_node_list = node_list[start:end + 1] inputs, outputs = _find_input_and_output_nodes(offload_node_list) - offload_inputs.append(inputs) - offload_outputs.append(outputs) + chunk_inputs.append(inputs) + chunk_outputs.append(outputs) # this flag is to prevent repeated insert of save tensors # hooks definition in ckpt_func @@ -427,10 +427,10 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod else: node = node_list[node_idx] - if node_idx in offload_starts: - offload_label = offload_labels[offload_starts.index(node_idx)] - _, offload_input, offload_bar = offload_label - within_offload_region = True + if node_idx in chunk_starts: + chunk_label = chunk_labels[chunk_starts.index(node_idx)] + _, chunk_input, chunk_bar = chunk_label + within_chunk_region = True # insert hook functions if needed if not is_hook_inserted: @@ -438,20 +438,20 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") is_hook_inserted = True - if offload_input and offload_bar: + if chunk_input and chunk_bar: body.append(_gen_save_on_cpu_context()) - elif offload_input: - for par in offload_inputs[offload_label[0]]: + elif chunk_input: + for par in chunk_inputs[chunk_label[0]]: body.append(f"setattr({par}, 'offload', True)\n") body.append(_gen_save_tensors_hooks_context(offload_input=True)) else: - for par in offload_inputs[offload_label[0]]: + for par in chunk_inputs[chunk_label[0]]: body.append(f"setattr({par}, 'offload', False)\n") body.append(_gen_save_tensors_hooks_context(offload_input=False)) - if within_offload_region: + if within_chunk_region: emit_node_func(node, body) body[-1] = ' ' + body[-1] delete_unused_value_func(node, body) @@ -460,150 +460,15 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod emit_node_func(node, body) delete_unused_value_func(node, body) - if node_idx in offload_ends: - within_offload_region = False + if node_idx in chunk_ends: + within_chunk_region = False node_idx += 1 -def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): - # find the activation checkpoint regions - ckpt_regions = _find_ckpt_regions(nodes) - start_idx = [item[0] for item in ckpt_regions] - end_idx = [item[1] for item in ckpt_regions] - input_vars = [] - output_vars = [] - within_ckpt_region = False - - # find the offload regions - offload_regions, offload_labels = _find_offload_regions(nodes) - offload_starts = [item[0] for item in offload_regions] - offload_ends = [item[1] for item in offload_regions] - offload_inputs = [] - offload_outputs = [] - within_offload_region = False - - node_list = list(nodes) - - # use this variable to avoid inserting hook functions - # to ckpt_func repeatedly - is_hook_inserted = False - - # find the input and output var names for each region - for idx, (start, end) in enumerate(ckpt_regions): - ckpt_node_list = node_list[start:end + 1] - inputs, outputs = _find_input_and_output_nodes(ckpt_node_list) - input_vars.append(inputs) - output_vars.append(outputs) - - # find the input and output var names for each offload region - for idx, (start, end) in enumerate(offload_regions): - offload_node_list = node_list[start:end + 1] - inputs, outputs = _find_input_and_output_nodes(offload_node_list) - offload_inputs.append(inputs) - offload_outputs.append(outputs) - - # append code text to body - for idx, node in enumerate(node_list): - # if this is the first node of the ckpt region - # append the ckpt function defition - if idx in start_idx: - label = start_idx.index(idx) - ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) - ckpt_func.append(f'{ckpt_fn_def}\n') - within_ckpt_region = True - - if idx in offload_starts: - offload_label = offload_labels[offload_starts.index(idx)] - _, offload_input, offload_bar = offload_label - within_offload_region = True - - # insert hook functions if needed - if not is_hook_inserted: - pack_hook, unpack_hook = _gen_saved_tensors_hooks() - ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n") - is_hook_inserted = True - - if offload_input and offload_bar: - body.append(_gen_save_on_cpu_context()) - - elif offload_input: - for par in offload_inputs[offload_label[0]]: - body.append(f"setattr({par}, 'offload', True)\n") - body.append(_gen_save_tensors_hooks_context(offload_input=True)) - - else: - for par in offload_inputs[offload_label[0]]: - body.append(f"setattr({par}, 'offload', False)\n") - body.append(_gen_save_tensors_hooks_context(offload_input=False)) - - # NOTE: emit_node does not emit a string with newline. It depends - # on delete_unused_values to append one - # NOTE: currently we separate body and ckpt_func definition - if within_ckpt_region: - emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] - delete_unused_value_func(node, ckpt_func) - - elif within_offload_region: - emit_node_func(node, body) - body[-1] = ' ' + body[-1] - delete_unused_value_func(node, body) - - else: - emit_node_func(node, body) - delete_unused_value_func(node, body) - - if idx in end_idx: - # if this is the last node of the ckpt region - # generate return statement - label = end_idx.index(idx) - return_statement = _gen_ckpt_output(output_vars[label]) - return_statement = f' {return_statement}\n\n' - ckpt_func.append(return_statement) - - # we need to check if the checkpoint need to offload the input - start_node_idx = start_idx[label] - if hasattr(node_list[start_node_idx], 'activation_offload'): - activation_offload = node_list[start_node_idx].activation_offload - else: - activation_offload = False - - # we need to check if the checkpoint need use_reentrant=False - use_reentrant = True - non_leaf_input = 0 - for var in input_vars[label]: - input_node = next(item for item in node_list if item.name == var) - if input_node.op != "placeholder": - non_leaf_input = 1 - for user in input_node.users: - if hasattr(user, "activation_checkpoint"): - if user.activation_checkpoint == label: - if user.op == "call_module": - if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): - use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace - - elif user.op == "call_function": - if "inplace" in user.kwargs: - use_reentrant = not user.kwargs["inplace"] - - # if all the inputs are leaf nodes, we need to set use_reentrant = False - if not non_leaf_input: - use_reentrant = False - - # generate checkpoint function call in a new line - usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) - usage += '\n' - body.append(usage) - within_ckpt_region = False - - if idx in offload_ends: - within_offload_region = False - - if CODEGEN_AVAILABLE: - class ActivationCheckpointCodeGen(CodeGen): + class ChunkCodeGen(CodeGen): def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] @@ -796,10 +661,7 @@ if CODEGEN_AVAILABLE: # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes): - emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) - else: - emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) + emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -833,215 +695,3 @@ if CODEGEN_AVAILABLE: {prologue} {code}""" return PythonCode(fn_code, globals_) - -else: - - def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode: - """ - This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it can generate - code for activation checkpoint. - """ - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - wrapped_fns: Dict[str, None] = {} - - # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] - - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. - - We call this for names that reference objects external to the - Graph, like functions or types. - - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj - return global_name - globals_[global_name] = obj - return global_name - - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) - - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) - - def type_repr(o: Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' - - typename = _type_repr(o) - - # This is a generic type, e.g. typing.List[torch.Tensor] - if hasattr(o, '__origin__'): - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) - - # Assign global names for each of the inner type variables. - args = [type_repr(arg) for arg in o.__args__] - - return f'{origin_typename}[{",".join(args)}]' - - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(self.nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - # NOTE: we add a variable to distinguish body and ckpt_func - def delete_unused_values(user: Node, body): - """ - Delete values after their last use. This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == 'placeholder': - return - if user.op == 'output': - body.append('\n') - return - nodes_to_delete = user_to_last_uses.get(user, []) - if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') - else: - body.append('\n') - - # NOTE: we add a variable to distinguish body and ckpt_func - def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': - assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') - if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') - return - elif node.op == 'call_method': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') - return - elif node.op == 'call_function': - assert callable(node.target) - # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: - assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') - return - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: - body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') - return - body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): - wrapped_fns.setdefault(global_name) - return - elif node.op == 'call_module': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') - return - elif node.op == 'get_attr': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') - return - elif node.op == 'output': - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - if self._pytree_info is None: - body.append(f'return {repr(node.args[0])}') - else: - body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)') - return - raise NotImplementedError(f'node: {node.op} {node.target}') - - # Modified for activation checkpointing - ckpt_func = [] - - # if any node has a list of labels for activation_checkpoint, we - # will use nested type of activation checkpoint codegen - if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes): - emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) - else: - emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the body - # have been emitted. To continue to have valid Python code, emit a - # single pass statement - body.append('pass\n') - if self._pytree_info is not None: - orig_args = self._pytree_info.orig_args - has_orig_self = (orig_args[0] == 'self') - if has_orig_self: - free_vars.insert(0, 'self') - if len(free_vars) > 0: # pytree has placeholders in it - body.insert( - 0, - f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n") - else: - orig_args = free_vars - - if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) - else: - wrap_stmts = '' - - ckpt_func = ''.join(ckpt_func) - - # If the original function didn't have self as its first argument, we - # would have added it. - if len(orig_args) == 0 or orig_args[0] != 'self': - orig_args.insert(0, 'self') - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) - - # as we need colossalai.utils.checkpoint, we need to import colossalai - # in forward function - fn_code = f""" -{wrap_stmts} - -{ckpt_func} -def forward({', '.join(orig_args)}){maybe_return_annotation[0]}: -{code}""" - return PythonCode(fn_code, globals_) diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py index 9ac399a29..85164bdad 100644 --- a/chunk_codegen_run.py +++ b/chunk_codegen_run.py @@ -11,7 +11,7 @@ from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule try: - from chunk_codegen import ActivationCheckpointCodeGen + from chunk_codegen import ChunkCodeGen with_codegen = True except: # fall back to older pytorch version @@ -75,7 +75,7 @@ def _run_offload_codegen(rank): # trace the module and replace codegen tracer = ColoTracer(trace_act_ckpt=True) graph = tracer.trace(model) - codegen = ActivationCheckpointCodeGen() + codegen = ChunkCodeGen() graph.set_codegen(codegen) # annotate the activation offload part @@ -99,15 +99,7 @@ def _run_offload_codegen(rank): # assert we have all the components code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + print(code) _test_fwd_and_bwd(model, gm, data) gpc.destroy() @@ -118,60 +110,5 @@ def test_act_ckpt_codegen(): mp.spawn(_run_offload_codegen, nprocs=1) -def _run_offload_codegen_torch11(rank): - # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') - - # build model and input - model = MyNet().cuda() - data = torch.rand(4, 4).cuda() - - # trace the module and replace codegen - tracer = ColoTracer(trace_act_ckpt=True) - graph = tracer.trace(model) - - # replace a bound method of an object - graph._python_code = python_code_with_activation_checkpoint.__get__(graph) - - # annotate the activation offload part - # also annotate the activation_checkpoint so we could test both types - # of input offload - for node in graph.nodes: - if node.name == "linear0": - setattr(node, "activation_offload", [0, True, False]) - if node.name == "linear1": - setattr(node, "activation_offload", [0, True, False]) - if node.name == "linear2": - setattr(node, "activation_offload", [1, True, True]) - if node.name == "linear4": - setattr(node, "activation_offload", [2, False, True]) - if node.name == "linear5": - setattr(node, "activation_checkpoint", [0]) - setattr(node, "activation_offload", True) - - gm = ColoGraphModule(copy.deepcopy(model), graph) - gm.recompile() - - # assert we have all the components - code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code - - _test_fwd_and_bwd(model, gm, data) - gpc.destroy() - - -@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented") -def test_act_ckpt_python_code_torch11(): - mp.spawn(_run_offload_codegen_torch11, nprocs=1) - - if __name__ == "__main__": _run_offload_codegen(0)