mirror of https://github.com/hpcaitech/ColossalAI

commit a591d45b29 ("add available")
parent 615e7e68d9
@@ -1,7 +1,12 @@
 from typing import Any, Dict, Iterable, List, Tuple
 
 import torch
-from torch.fx.graph import (
+
+import colossalai
+from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
+
+if CODEGEN_AVAILABLE:
+    from torch.fx.graph import (
         CodeGen,
         PythonCode,
         _custom_builtins,
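The point of the new guard: torch.fx.graph.CodeGen and PythonCode only exist in sufficiently recent torch releases, so importing them unconditionally would crash module import on older wheels. A minimal sketch of how an availability flag like CODEGEN_AVAILABLE is typically computed; the actual definition lives in colossalai/fx/codegen/activation_checkpoint_codegen.py and may differ:

# Sketch only: the real flag in activation_checkpoint_codegen.py may be
# computed differently, but a torch-version gate is the usual approach,
# since torch.fx.graph.CodeGen first shipped with torch 1.12.
import torch
from packaging import version

CODEGEN_AVAILABLE = version.parse(torch.__version__) >= version.parse("1.12.0")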

@@ -12,10 +17,9 @@ from torch.fx.graph import (
         _origin_type_map,
         inplace_methods,
         magic_methods,
     )
+
 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
 
-import colossalai
-
 from .search_chunk import SearchChunk
 from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape

@@ -302,7 +306,9 @@ def emit_code_with_chunk(
         node_idx += 1
 
 
-class AutoChunkCodeGen(CodeGen):
+if CODEGEN_AVAILABLE:
+
+    class AutoChunkCodeGen(CodeGen):
         def __init__(self, meta_graph, max_memory=None, print_mem=False):
             super().__init__()
             self.meta_graph = meta_graph
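Because AutoChunkCodeGen subclasses CodeGen, and a class statement executes at import time, the definition itself has to move under the same guard; that re-indents the whole class body by four spaces, which is why the later hunks in this view show mostly line re-wrapping rather than logic changes. A sketch of the conditional-definition pattern; the else branch below is hypothetical (not part of the commit) and only makes the failure mode explicit:

from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

if CODEGEN_AVAILABLE:
    from torch.fx.graph import CodeGen

    class AutoChunkCodeGen(CodeGen):
        ...  # real implementation as in the diff

else:

    # Hypothetical fallback, not in the commit: fail loudly instead of
    # raising ImportError at module load.
    class AutoChunkCodeGen:
        def __init__(self, *args, **kwargs):
            raise RuntimeError("AutoChunkCodeGen needs torch >= 1.12 "
                               "(torch.fx.graph.CodeGen is unavailable).")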

@@ -349,7 +355,9 @@ class AutoChunkCodeGen(CodeGen):
                 return global_name
 
             # set _custom_builtins here so that we needn't import colossalai in forward
-            _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
+            _custom_builtins["colossalai"] = _CustomBuiltin(
+                "import colossalai", colossalai
+            )
 
             # Pre-fill the globals table with registered builtins.
             for name, (_, obj) in _custom_builtins.items():
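_custom_builtins is torch.fx's registry of names that generated forward() source may reference without an explicit import statement in the emitted code; each entry pairs an import string with the live object, and code generation copies the objects into the globals of the compiled forward. Registering colossalai there is what lets the chunked forward call colossalai helpers directly, per the comment in the hunk. A small demonstration built from the same private torch.fx internals the file already imports:

# Demonstrates what the registration achieves: entries in _custom_builtins
# end up in the globals table of the generated forward, so its source can
# use the name "colossalai" directly. Private torch.fx internals.
import colossalai
from torch.fx.graph import _custom_builtins, _CustomBuiltin

_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

# Pre-fill a globals table the same way the codegen does.
globals_ = {name: obj for name, (_, obj) in _custom_builtins.items()}
assert globals_["colossalai"] is colossalai  # generated forward can call it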

@@ -449,7 +457,9 @@ class AutoChunkCodeGen(CodeGen):
                 )
                 if node.op == "placeholder":
                     assert isinstance(node.target, str)
-                    maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+                    maybe_default_arg = (
+                        "" if not node.args else f" = {repr(node.args[0])}"
+                    )
                     free_vars.append(
                         f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
                     )
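This branch renders the signature text for placeholder nodes: node.target is the parameter name, and a default value recorded during tracing is formatted with repr. A plain-string illustration of what the reflowed expression produces; the annotation and default below are made-up values, not taken from the commit:

# Stand-ins for the node fields used by the placeholder branch.
target = "x"
maybe_type_annotation = ": torch.Tensor"  # illustrative annotation text
args = (None,)                            # traced default value, if any

maybe_default_arg = "" if not args else f" = {repr(args[0])}"
print(f"{target}{maybe_type_annotation}{maybe_default_arg}")
# -> x: torch.Tensor = None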

@@ -553,7 +563,9 @@ class AutoChunkCodeGen(CodeGen):
 
             if len(wrapped_fns) > 0:
                 wrap_name = add_global("wrap", torch.fx.wrap)
-                wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+                wrap_stmts = "\n".join(
+                    [f'{wrap_name}("{name}")' for name in wrapped_fns]
+                )
             else:
                 wrap_stmts = ""
 
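wrap_stmts emits one torch.fx.wrap(...) call per function that tracing marked as a leaf, so retracing the generated module keeps those calls opaque. A self-contained sketch, with a stand-in for the nested add_global helper and hypothetical entries in wrapped_fns:

import torch.fx

wrapped_fns = {"len": None, "my_leaf_fn": None}  # hypothetical entries

def add_global(name, obj):
    # Stand-in: the real nested helper also records obj in the globals table.
    return name

if len(wrapped_fns) > 0:
    wrap_name = add_global("wrap", torch.fx.wrap)
    wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
    wrap_stmts = ""

print(wrap_stmts)
# wrap("len")
# wrap("my_leaf_fn")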

@@ -572,9 +584,9 @@ class AutoChunkCodeGen(CodeGen):
             code = "".join(body)
             code = "\n".join("    " + line for line in code.split("\n"))
             fn_code = f"""
 {wrap_stmts}
 
 {prologue}
 {code}"""
             # print(fn_code)
             return PythonCode(fn_code, globals_)
(diff truncated: the remaining hunks of this commit were not loaded on the mirrored page)
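For orientation, a hedged usage sketch of where a CodeGen subclass like AutoChunkCodeGen plugs in. TinyModel is a placeholder, and passing the traced module itself as the meta graph is an assumption made only to show the wiring; the real workflow builds a meta-propagated graph first:

import torch
import torch.fx

class TinyModel(torch.nn.Module):  # placeholder model
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(TinyModel())

# Assumption: reusing gm as the meta graph just to show the wiring; the
# documented autochunk flow prepares a meta-annotated graph instead.
codegen = AutoChunkCodeGen(gm, max_memory=None, print_mem=False)
gm.graph.set_codegen(codegen)  # Graph.set_codegen exists in torch.fx >= 1.12
gm.recompile()                 # regenerates forward() through the new codegen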