mirror of https://github.com/hpcaitech/ColossalAI
[fx] fix offload codegen test (#1648)
* [fx] fix offload codegen test * [fx] modify typingpull/1654/head
parent
45b39a692a
commit
5d0fdb9cb4
|
@ -1,6 +1,6 @@
|
|||
import colossalai
|
||||
import torch
|
||||
from typing import List, Callable, Any, Tuple, Dict
|
||||
from typing import List, Callable, Any, Tuple, Dict, Iterable
|
||||
|
||||
try:
|
||||
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
|
||||
|
@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
|
|||
current_region = None
|
||||
|
||||
for idx, node in enumerate(nodes):
|
||||
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list):
|
||||
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
|
||||
act_offload_label = node.activation_offload
|
||||
|
||||
if current_region == None:
|
||||
|
@ -796,7 +796,7 @@ if CODEGEN_AVAILABLE:
|
|||
|
||||
# if any node has a list of labels for activation_checkpoint, we
|
||||
# will use nested type of activation checkpoint codegen
|
||||
if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
|
||||
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
|
||||
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
|
||||
else:
|
||||
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
|
||||
|
@ -999,7 +999,7 @@ else:
|
|||
|
||||
# if any node has a list of labels for activation_checkpoint, we
|
||||
# will use nested type of activation checkpoint codegen
|
||||
if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
|
||||
if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
|
||||
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
|
||||
else:
|
||||
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
|
||||
|
|
|
@ -83,13 +83,13 @@ def _run_offload_codegen(rank):
|
|||
# of input offload
|
||||
for node in graph.nodes:
|
||||
if node.name == "linear0":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear1":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear2":
|
||||
setattr(node, "activation_offload", (1, True, True))
|
||||
setattr(node, "activation_offload", [1, True, True])
|
||||
if node.name == "linear4":
|
||||
setattr(node, "activation_offload", (2, False, True))
|
||||
setattr(node, "activation_offload", [2, False, True])
|
||||
if node.name == "linear5":
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", True)
|
||||
|
@ -138,13 +138,13 @@ def _run_offload_codegen_torch11(rank):
|
|||
# of input offload
|
||||
for node in graph.nodes:
|
||||
if node.name == "linear0":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear1":
|
||||
setattr(node, "activation_offload", (0, True, False))
|
||||
setattr(node, "activation_offload", [0, True, False])
|
||||
if node.name == "linear2":
|
||||
setattr(node, "activation_offload", (1, True, True))
|
||||
setattr(node, "activation_offload", [1, True, True])
|
||||
if node.name == "linear4":
|
||||
setattr(node, "activation_offload", (2, False, True))
|
||||
setattr(node, "activation_offload", [2, False, True])
|
||||
if node.name == "linear5":
|
||||
setattr(node, "activation_checkpoint", [0])
|
||||
setattr(node, "activation_offload", True)
|
||||
|
|
Loading…
Reference in New Issue