add available

pull/2364/head
oahzxl 2023-01-10 10:56:39 +08:00
parent 615e7e68d9
commit a591d45b29
1 changed files with 266 additions and 254 deletions

View File

@ -1,7 +1,12 @@
from typing import Any, Dict, Iterable, List, Tuple
import torch
from torch.fx.graph import (
import colossalai
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
if CODEGEN_AVAILABLE:
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
@ -12,10 +17,9 @@ from torch.fx.graph import (
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
import colossalai
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
@ -302,7 +306,9 @@ def emit_code_with_chunk(
node_idx += 1
class AutoChunkCodeGen(CodeGen):
if CODEGEN_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self, meta_graph, max_memory=None, print_mem=False):
super().__init__()
self.meta_graph = meta_graph
@ -349,7 +355,9 @@ class AutoChunkCodeGen(CodeGen):
return global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
_custom_builtins["colossalai"] = _CustomBuiltin(
"import colossalai", colossalai
)
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
@ -449,7 +457,9 @@ class AutoChunkCodeGen(CodeGen):
)
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
maybe_default_arg = (
"" if not node.args else f" = {repr(node.args[0])}"
)
free_vars.append(
f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
)
@ -553,7 +563,9 @@ class AutoChunkCodeGen(CodeGen):
if len(wrapped_fns) > 0:
wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
wrap_stmts = "\n".join(
[f'{wrap_name}("{name}")' for name in wrapped_fns]
)
else:
wrap_stmts = ""
@ -572,9 +584,9 @@ class AutoChunkCodeGen(CodeGen):
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{wrap_stmts}
{prologue}
{code}"""
{prologue}
{code}"""
# print(fn_code)
return PythonCode(fn_code, globals_)