diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index 391d64405..e622b8d3a 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -109,6 +109,209 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
     return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
 
 
+def _end_of_ckpt(node: Node, check_idx: int) -> bool:
+    """Check whether the node could end the ckpt region
+
+    Args:
+        node (Node): torch.fx.Node
+        check_idx (int): the index of the checkpoint level for
+            nested checkpointing
+
+    Returns:
+        bool
+    """
+    if hasattr(node, "activation_checkpoint"):
+        if isinstance(node.activation_checkpoint, list):
+            return node.activation_checkpoint[check_idx] is None
+        else:
+            return False
+    else:
+        return True
+
+
+def _find_nested_ckpt_regions(nodes, check_idx=0):
+    """
+    Find the nested checkpoint regions given a list of consecutive nodes. The output
+    is a list of tuples, each in the form of (start_index, end_index).
+    """
+    ckpt_regions = []
+    start = -1
+    end = -1
+    current_region = None
+
+    for idx, node in enumerate(nodes):
+        if hasattr(node, 'activation_checkpoint'):
+            if isinstance(getattr(node, 'activation_checkpoint'), int):
+                act_ckpt_label = node.activation_checkpoint
+            else:
+                act_ckpt_label = node.activation_checkpoint[check_idx]
+
+            # this activation checkpoint label is not tracked yet,
+            # meaning this is the first node of an activation ckpt region
+            if current_region is None:
+                current_region = act_ckpt_label
+                start = idx
+
+            # if the activation checkpoint label has changed,
+            # we restart the tracking
+            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
+            if act_ckpt_label != current_region:
+                assert start != -1
+                ckpt_regions.append((start, idx - 1))
+                current_region = act_ckpt_label
+                start = idx
+                end = -1
+        elif current_region is not None and _end_of_ckpt(node, check_idx):
+            # used to handle the case below
+            # node ckpt states = [ckpt, ckpt, non-ckpt]
+            end = idx - 1
+            assert start != -1 and end != -1
+            ckpt_regions.append((start, end))
+            start = end = -1
+            current_region = None
+        else:
+            pass
+
+    if current_region is not None:
+        end = len(nodes) - 1
+        ckpt_regions.append((start, end))
+    return ckpt_regions
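For intuition, here is roughly how the two helpers above cooperate. This is a
minimal sketch, not part of the patch: the private import path simply mirrors
the file patched above, and DummyNode is a hypothetical stand-in for
torch.fx.Node that carries only the annotation attribute.

    from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions

    class DummyNode:
        """Stand-in for torch.fx.Node; only the annotation matters here."""

        def __init__(self, ckpt=None):
            if ckpt is not None:
                self.activation_checkpoint = ckpt

    nodes = [
        DummyNode([0, 0]),    # opens top-level region 0
        DummyNode([0, 1]),    # same top-level label, stays in region 0
        DummyNode(),          # un-annotated node ends the region
        DummyNode(1),         # plain int labels are recognized as well
    ]

    # at check_idx=0 the top-level labels read [0, 0, <none>, 1],
    # so two regions are expected
    print(_find_nested_ckpt_regions(nodes, 0))    # [(0, 1), (3, 3)]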
+ """ + inputs, outputs = _find_input_and_output_nodes(node_list) + + # if the current checkpoint function use int as label, using old generation method + if isinstance(node_list[0].activation_checkpoint, int): + label = node_list[0].activation_checkpoint + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n') + activation_offload = getattr(node_list[0], "activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + usage += "\n" + body.append(usage) + + # use nested ckpt function codegen + else: + # label given by each layer, e.g. if you are currently at level [0, 1, 1] + # the label will be '0_1_1' + label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]]) + ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) + ckpt_func.append(f'{ckpt_fn_def}\n') + + # if there is more level to fetch + if level + 1 < len(node_list[0].activation_checkpoint): + ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) + start_idx = [item[0] for item in ckpt_regions] + end_idx = [item[1] for item in ckpt_regions] + + # use ckpt_func_buffer to store nested checkpoint functions + ckpt_func_buffer = [] + node_idx = 0 + while 1: + if node_idx >= len(node_list): + break + + if node_idx in start_idx: + ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, + delete_unused_value_func, level + 1, True) + node_idx += len(ckpt_node_list) + + else: + node = node_list[node_idx] + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + node_idx += 1 + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n') + ckpt_func += ckpt_func_buffer + activation_offload = getattr(node_list[0], "activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + # last level + else: + for node in node_list: + emit_node_func(node, ckpt_func) + ckpt_func[-1] = ' ' + ckpt_func[-1] + delete_unused_value_func(node, ckpt_func) + + ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n') + activation_offload = getattr(node_list[0], "activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + if in_ckpt: + usage = ' ' + usage + body.append(usage) + + +def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func): + """Emit code with nested activation checkpoint + When we detect some of the node.activation_checkpoint is a List, we will use + this function to emit the activation checkpoint codes. 
+
+
+def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
+    """Emit code with nested activation checkpointing
+
+    When we detect that some node.activation_checkpoint is a list, we use
+    this function to emit the activation checkpoint code.
+
+    Args:
+        body: forward code
+        ckpt_func: checkpoint function code
+        nodes: graph.nodes
+        emit_node_func: function to emit a node
+        delete_unused_value_func: function to delete unused values
+    """
+    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
+    start_idx = [item[0] for item in ckpt_regions]
+    end_idx = [item[1] for item in ckpt_regions]
+
+    node_list = list(nodes)
+
+    node_idx = 0
+    while node_idx < len(node_list):
+        # process a ckpt region as a whole
+        if node_idx in start_idx:
+            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+            emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
+            node_idx += len(ckpt_node_list)
+
+        # emit a plain node into the forward function
+        else:
+            node = node_list[node_idx]
+            emit_node_func(node, body)
+            delete_unused_value_func(node, body)
+            node_idx += 1
+
+
 def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
     # find the activation checkpoint regions
     ckpt_regions = _find_ckpt_regions(nodes)
@@ -384,7 +587,10 @@ if CODEGEN_AVAILABLE:
 
             # Modified for activation checkpointing
             ckpt_func = []
-            emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+            if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
+                emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+            else:
+                emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
 
             if len(body) == 0:
                 # If the Graph has no non-placeholder nodes, no lines for the body
@@ -582,7 +788,10 @@ else:
 
         # Modified for activation checkpointing
         ckpt_func = []
-        emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
+            emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
 
         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
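The dispatch added in the two hunks above reduces to a single question: does
any node carry a list annotation? A hedged restatement for clarity; the helper
name _use_nested_codegen is ours, not part of the patch:

    # equivalent to the negated all(...) condition used in both hunks
    def _use_nested_codegen(nodes) -> bool:
        return any(
            isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes)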
diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py
new file mode 100644
index 000000000..56f25175e
--- /dev/null
+++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py
@@ -0,0 +1,153 @@
+import torch
+import torch.nn.functional as F
+import pytest
+import torch.multiprocessing as mp
+from torch.utils.checkpoint import checkpoint
+from torch.fx import GraphModule
+from colossalai.fx import ColoTracer
+import colossalai
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+
+try:
+    from colossalai.fx.codegen import ActivationCheckpointCodeGen
+    with_codegen = True
+except:
+    # fall back to older pytorch version
+    from colossalai.fx.codegen import python_code_with_activation_checkpoint
+    with_codegen = False
+
+
+class MyModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(4, 4)
+        self.linear2 = torch.nn.Linear(4, 4)
+        self.linear3 = torch.nn.Linear(4, 4)
+        self.linear4 = torch.nn.Linear(4, 4)
+        self.linear5 = torch.nn.Linear(4, 4)
+        self.linear6 = torch.nn.Linear(4, 4)
+
+    def forward(self, x):
+        return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x))))))
+
+
+def _run_act_ckpt_codegen(rank):
+    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
+    # build the model and run a forward pass
+    model = MyModule()
+    data1 = torch.rand(4, 4)
+
+    # move the model and data to cuda
+    model = model.to(device="cuda")
+    data1 = data1.to(device="cuda")
+
+    non_fx_out = model(data1)
+
+    # trace the module and replace the codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+    codegen = ActivationCheckpointCodeGen()
+    graph.set_codegen(codegen)
+
+    # annotate the nested checkpoints
+    for node in graph.nodes:
+        if node.name == "linear1":
+            setattr(node, "activation_checkpoint", [0, 0, 0])
+            continue
+        if node.name == "linear2":
+            setattr(node, "activation_checkpoint", [0, 0, None])
+        if node.name == "linear3":
+            setattr(node, "activation_checkpoint", [0, 0, 1])
+        if node.name == "linear4":
+            setattr(node, "activation_checkpoint", [0, 1, None])
+        if node.name == "linear5":
+            setattr(node, "activation_checkpoint", 1)
+    gm = ColoGraphModule(model, graph)
+    gm.recompile()
+
+    # assert that the expected checkpoint wrappers are generated
+    code = graph.python_code('self').src
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code
+
+    # run the recompiled module and verify the outputs are consistent
+    fx_out = gm(data1)
+    assert torch.equal(non_fx_out, fx_out)
+
+    gpc.destroy()
+
+
+@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+def test_act_ckpt_codegen():
+    mp.spawn(_run_act_ckpt_codegen, nprocs=1)
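To see why the test expects exactly those six wrappers, the annotations can be
expanded with the same "_".join rule used by emit_ckpt_func. A sketch with a
hypothetical helper, written for illustration only:

    annotations = {
        "linear1": [0, 0, 0],
        "linear2": [0, 0, None],
        "linear3": [0, 0, 1],
        "linear4": [0, 1, None],
        "linear5": 1,
    }

    def labels(ckpt):
        # an int annotation yields a single top-level wrapper
        if isinstance(ckpt, int):
            yield str(ckpt)
            return
        # a list annotation yields one wrapper per non-None nesting level
        for level in range(len(ckpt)):
            if ckpt[level] is not None:
                yield "_".join(str(idx) for idx in ckpt[:level + 1])

    print(sorted({lab for ckpt in annotations.values() for lab in labels(ckpt)}))
    # ['0', '0_0', '0_0_0', '0_0_1', '0_1', '1']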
+
+
+def _run_act_ckpt_python_code_torch11(rank):
+    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
+    # build the model and run a forward pass
+    model = MyModule()
+    data1 = torch.rand(4, 4)
+
+    # move the model and data to cuda
+    model = model.to(device="cuda")
+    data1 = data1.to(device="cuda")
+
+    non_fx_out = model(data1)
+
+    # trace the module and replace the codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+    codegen = ActivationCheckpointCodeGen()
+    graph.set_codegen(codegen)
+
+    # annotate the nested checkpoints
+    for node in graph.nodes:
+        if node.name == "linear1":
+            setattr(node, "activation_checkpoint", [0, 0, 0])
+            continue
+        if node.name == "linear2":
+            setattr(node, "activation_checkpoint", [0, 0, None])
+        if node.name == "linear3":
+            setattr(node, "activation_checkpoint", [0, 0, 1])
+        if node.name == "linear4":
+            setattr(node, "activation_checkpoint", [0, 1, None])
+        if node.name == "linear5":
+            setattr(node, "activation_checkpoint", 1)
+    gm = ColoGraphModule(model, graph)
+    gm.recompile()
+
+    # assert that the expected checkpoint wrappers are generated
+    code = graph.python_code('self').src
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code
+
+    # run the recompiled module and verify the outputs are consistent
+    fx_out = gm(data1)
+    assert torch.equal(non_fx_out, fx_out)
+
+    gpc.destroy()
+
+
+@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
+def test_act_ckpt_python_code_torch11():
+    mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)
+
+
+if __name__ == '__main__':
+    _run_act_ckpt_codegen(rank=0)
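For reference, under this annotation the traced forward should collapse into
two top-level checkpoint calls plus the un-annotated linear6. A rough sketch of
the expected shape of the generated code, not verbatim codegen output:

    # def forward(self, x):
    #     linear4 = colossalai.utils.activation_checkpoint.checkpoint(
    #         self.checkpoint_0, False, x, use_reentrant=False)
    #     linear5 = colossalai.utils.activation_checkpoint.checkpoint(
    #         self.checkpoint_1, False, linear4, use_reentrant=False)
    #     linear6 = self.linear6(linear5)
    #     return linear6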