diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 5bb04a68e..684028c01 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -1,6 +1,6 @@ import colossalai import torch -from typing import List, Callable, Any, Tuple, Dict +from typing import List, Callable, Any, Tuple, Dict, Iterable try: from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name @@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]): current_region = None for idx, node in enumerate(nodes): - if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list): + if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable): act_offload_label = node.activation_offload if current_region == None: @@ -796,7 +796,7 @@ if CODEGEN_AVAILABLE: # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes): + if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) @@ -999,7 +999,7 @@ else: # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes): + if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index 271d86a80..edaeb50cb 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -83,13 +83,13 @@ def _run_offload_codegen(rank): # of input offload for node in graph.nodes: if node.name == "linear0": - setattr(node, "activation_offload", (0, True, False)) + setattr(node, "activation_offload", [0, True, False]) if node.name == "linear1": - setattr(node, "activation_offload", (0, True, False)) + setattr(node, "activation_offload", [0, True, False]) if node.name == "linear2": - setattr(node, "activation_offload", (1, True, True)) + setattr(node, "activation_offload", [1, True, True]) if node.name == "linear4": - setattr(node, "activation_offload", (2, False, True)) + setattr(node, "activation_offload", [2, False, True]) if node.name == "linear5": setattr(node, "activation_checkpoint", [0]) setattr(node, "activation_offload", True) @@ -138,13 +138,13 @@ def _run_offload_codegen_torch11(rank): # of input offload for node in graph.nodes: if node.name == "linear0": - setattr(node, "activation_offload", (0, True, False)) + setattr(node, "activation_offload", [0, True, False]) if node.name == "linear1": - setattr(node, "activation_offload", (0, True, False)) + setattr(node, "activation_offload", [0, True, False]) if node.name == "linear2": - setattr(node, "activation_offload", (1, True, True)) + setattr(node, "activation_offload", [1, True, True]) if node.name == "linear4": - setattr(node, "activation_offload", (2, False, True)) + setattr(node, "activation_offload", [2, False, True]) if node.name == "linear5": setattr(node, "activation_checkpoint", [0]) setattr(node, "activation_offload", True)