mirror of https://github.com/hpcaitech/ColossalAI
remove autochunk_available
parent
aafc3516a5
commit
19cc64b1d3
|
@ -16,13 +16,9 @@ from torch.fx.graph import (
|
|||
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
|
||||
|
||||
import colossalai
|
||||
|
||||
from .search_chunk import SearchChunk
|
||||
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
|
||||
|
||||
CODEGEN_AVAILABLE = True
|
||||
__all__ = ["AutoChunkCodeGen"]
|
||||
|
||||
|
||||
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
|
||||
new_shape = "["
|
||||
|
@ -222,9 +218,7 @@ def emit_code_with_chunk(
|
|||
node_idx += 1
|
||||
|
||||
|
||||
if CODEGEN_AVAILABLE:
|
||||
|
||||
class AutoChunkCodeGen(CodeGen):
|
||||
class AutoChunkCodeGen(CodeGen):
|
||||
def __init__(self, meta_graph, max_memory=None, print_mem=False):
|
||||
super().__init__()
|
||||
self.meta_graph = meta_graph
|
||||
|
@ -271,9 +265,7 @@ if CODEGEN_AVAILABLE:
|
|||
return global_name
|
||||
|
||||
# set _custom_builtins here so that we needn't import colossalai in forward
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin(
|
||||
"import colossalai", colossalai
|
||||
)
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
|
@ -373,9 +365,7 @@ if CODEGEN_AVAILABLE:
|
|||
)
|
||||
if node.op == "placeholder":
|
||||
assert isinstance(node.target, str)
|
||||
maybe_default_arg = (
|
||||
"" if not node.args else f" = {repr(node.args[0])}"
|
||||
)
|
||||
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
|
||||
free_vars.append(
|
||||
f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
|
||||
)
|
||||
|
@ -479,9 +469,7 @@ if CODEGEN_AVAILABLE:
|
|||
|
||||
if len(wrapped_fns) > 0:
|
||||
wrap_name = add_global("wrap", torch.fx.wrap)
|
||||
wrap_stmts = "\n".join(
|
||||
[f'{wrap_name}("{name}")' for name in wrapped_fns]
|
||||
)
|
||||
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
|
||||
else:
|
||||
wrap_stmts = ""
|
||||
|
||||
|
|
Loading…
Reference in New Issue