remove autochunk_available

pull/2364/head
oahzxl 2023-01-09 16:06:58 +08:00
parent aafc3516a5
commit 19cc64b1d3
1 changed file with 237 additions and 249 deletions


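The change removes the module-level CODEGEN_AVAILABLE flag and the `if CODEGEN_AVAILABLE:` guard, so AutoChunkCodeGen is defined unconditionally at module level. Below is a minimal sketch of the before/after structure, assuming (as the hunk context in the diff suggests) that the class used to be nested under the guard; the constructor signature is the one shown in the diff and the bodies are elided.

# Sketch only, not the actual file contents.
from torch.fx.graph import CodeGen

# Before (assumed layout): codegen was gated behind a flag and indented under it.
#
#     CODEGEN_AVAILABLE = True
#
#     if CODEGEN_AVAILABLE:
#         class AutoChunkCodeGen(CodeGen):
#             ...
#
# After: the flag is gone and the class sits at module level.
class AutoChunkCodeGen(CodeGen):
    def __init__(self, meta_graph, max_memory=None, print_mem=False):
        super().__init__()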
@@ -16,13 +16,9 @@ from torch.fx.graph import (
 from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
 import colossalai
 from .search_chunk import SearchChunk
 from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
-CODEGEN_AVAILABLE = True
 __all__ = ["AutoChunkCodeGen"]

 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
     new_shape = "["
@@ -222,8 +218,6 @@ def emit_code_with_chunk(
     node_idx += 1

-if CODEGEN_AVAILABLE:
 class AutoChunkCodeGen(CodeGen):
     def __init__(self, meta_graph, max_memory=None, print_mem=False):
         super().__init__()
@@ -271,9 +265,7 @@ if CODEGEN_AVAILABLE:
     return global_name

 # set _custom_builtins here so that we needn't import colossalai in forward
-_custom_builtins["colossalai"] = _CustomBuiltin(
-    "import colossalai", colossalai
-)
+_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)

 # Pre-fill the globals table with registered builtins.
 for name, (_, obj) in _custom_builtins.items():
@@ -373,9 +365,7 @@ if CODEGEN_AVAILABLE:
 )
 if node.op == "placeholder":
     assert isinstance(node.target, str)
-    maybe_default_arg = (
-        "" if not node.args else f" = {repr(node.args[0])}"
-    )
+    maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
     free_vars.append(
         f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
     )
@@ -479,9 +469,7 @@ if CODEGEN_AVAILABLE:
 if len(wrapped_fns) > 0:
     wrap_name = add_global("wrap", torch.fx.wrap)
-    wrap_stmts = "\n".join(
-        [f'{wrap_name}("{name}")' for name in wrapped_fns]
-    )
+    wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
 else:
     wrap_stmts = ""