[autochunk] support vit (#3084)

support vit for autochunk
* support some new ops for vit
* fix some bugs
* add test for vit
Xuanlei Zhao 2023-03-10 10:23:26 +08:00 committed by GitHub
parent e58a3c804c
commit 10c61de2f7
8 changed files with 445 additions and 57 deletions
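
For orientation, here is where the newly supported ops come from in a ViT forward pass. This is a paraphrase of the usual timm-style implementation with invented shapes (vit_large_patch16_384-like sizes), not code from this PR: flatten turns the patch grid into tokens, expand broadcasts the cls token over the batch, and unbind splits the fused qkv projection into q, k and v.

import torch

B, C, heads = 1, 1024, 16
patches = torch.rand(B, C, 24, 24).flatten(2).transpose(1, 2)    # flatten: (B, C, 24, 24) -> (B, 576, C) tokens
cls_token = torch.zeros(1, 1, C).expand(B, -1, -1)               # expand: broadcast the cls token over the batch
tokens = torch.cat((cls_token, patches), dim=1)                  # (B, 577, C)
qkv = torch.rand(B, 577, 3 * C).reshape(B, 577, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)                                          # unbind: one (B, heads, 577, C // heads) tensor per slice
print(tokens.shape, q.shape)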


@@ -63,7 +63,7 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup
context = ""
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) == "split":
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
input_node.name)
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
@@ -205,7 +205,7 @@ def _add_node_slice(
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
get_node_shape(chunk_node))
if get_node_name(chunk_node) == "split":
if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
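
Both hunks above extend the existing split-only handling to unbind, and both count outputs via len(node.meta['tensor_meta']). Below is a standalone sanity check of that assumption using vanilla torch.fx ShapeProp (the PR itself uses Colossal-AI's tracer and MetaInfoProp, but the metadata layout is analogous): a multi-output op such as torch.unbind stores one TensorMetadata entry per output.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class Unbind(torch.nn.Module):
    def forward(self, x):
        parts = torch.unbind(x, dim=0)           # one call_function node with three outputs
        return parts[0] + parts[1] + parts[2]

gm = symbolic_trace(Unbind())
ShapeProp(gm).propagate(torch.rand(3, 4, 8))
unbind_node = next(n for n in gm.graph.nodes if n.target is torch.unbind)
print(len(unbind_node.meta["tensor_meta"]))      # 3: one metadata entry per unbound slice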


@@ -74,6 +74,9 @@ class TraceIndice(object):
"""
add a dim for indice, compute and source
"""
# need to remap if dim_idx < 0, e.g. -1
if dim_idx < 0:
dim_idx = list(range(len(self.indice_trace_list[node_idx]["indice"]) + 1))[dim_idx]
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
@@ -575,6 +578,60 @@ class TraceIndice(object):
cat_dim = node.kwargs["dim"]
self._del_dim(node_idx, cat_dim)
def _assign_flatten_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for flatten op.
Args:
node (node)
node_idx (int)
"""
nodes_in = node.args[0]
nodes_in_shape = get_node_shape(nodes_in)
flatten_start_dim = node.args[1]
flatten_dim_num = len(nodes_in_shape) - flatten_start_dim - 1
assert flatten_dim_num > 0
for _ in range(flatten_dim_num):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx, nodes_in)
for _ in range(flatten_dim_num + 1):
self._del_dim(node_idx, -1)
self._add_dim(node_idx, -1)
def _assign_expand_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for expand op.
Args:
node (node)
node_idx (int)
"""
expand_shape = node.args[1:]
node_in_shape = get_node_shape(node.args[0])
assert len(expand_shape) == len(node_in_shape)
self._assign_indice_as_input(node, node_idx)
for i in range(len(node_in_shape)):
if expand_shape[i] == node_in_shape[i] or expand_shape[i] == -1:
continue
elif expand_shape[i] > node_in_shape[i]:
self._del_dim(node_idx, i)
self._add_dim(node_idx, i)
else:
raise RuntimeError()
def _assign_unbind_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for unbind op.
Args:
node (node)
node_idx (int)
"""
unbind_dim = node.args[1]
self._add_dim(node_idx, unbind_dim)
self._assign_indice_as_input(node, node_idx)
self._del_dim(node_idx, unbind_dim)
def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
"""
Assign indice for embedding op.
@@ -695,32 +752,39 @@ class TraceIndice(object):
shape_idx = target_shape.index(-1)
target_shape[shape_idx] = origin_product // target_product
# determine changed dim
len_diff = len(origin_shape) - len(target_shape)
if len_diff == 1:
# find same dim
dim_to_same_dim = []
dim_from_same_dim = []
for i in range(len(origin_shape)):
if origin_shape[i] == target_shape[i]:
dim_to_same_dim.append(i)
dim_from_same_dim.append(i)
else:
break
for i in range(-1, -len(origin_shape), -1):
if origin_shape[i] == target_shape[i]:
dim_to_same_dim.append(len(target_shape) + i)
dim_from_same_dim.append(len(origin_shape) + i)
else:
break
dim_from = list(set(range(len(origin_shape))) - set(dim_from_same_dim))
dim_to = list(set(range(len(target_shape))) - set(dim_to_same_dim))
assert len(dim_from) == 1 or len(dim_to) == 1 or len(dim_from) == len(dim_to)
dim_diff = len(dim_from) - len(dim_to)
if dim_diff > 0:
# dim merge
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
dim_to = [dim_equal.index(False)]
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
self._add_dim(node_idx, -1)
elif len_diff == -1:
for i in range(dim_diff):
self._add_dim(node_idx, -1)
elif dim_diff < 0:
# dim expand
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
dim_from = [dim_equal.index(False)]
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
self._del_dim(node_idx, -1)
elif len_diff == 0:
# dim equal
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
dim_from = []
dim_to = []
else:
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
for i in range(-dim_diff):
self._del_dim(node_idx, -1)
# get new indice
origin_trace = self._find_indice_trace_from_node(origin_node)
self._assign_indice_as_input(node, node_idx, origin_node)
idx_from = [origin_trace[i] for i in dim_from]
dim_from.reverse()
for i in dim_from:
self._del_dim(node_idx, i)
@@ -728,36 +792,18 @@ class TraceIndice(object):
self._add_dim(node_idx, i)
dim_from.reverse()
# search view list
# for view_node, view_dict in self.indice_view_list.items():
# if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
# and view_dict["dim_from"] == dim_to):
# # inheirt indice from current node
# if len_diff == 1:
# if origin_shape[dim_from[0]] == 1:
# self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
# elif origin_shape[dim_from[1]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# elif len_diff == -1:
# if target_shape[dim_to[0]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
# elif target_shape[dim_to[1]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# # inherid indice from input node of last view
# for dim_to_i in dim_to:
# self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# inheirt indice from current node
if len_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif len_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
if len(dim_from) != 0 and len(dim_to) != 0:
if dim_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif dim_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# log view, not used now
view_dict = {
@@ -809,6 +855,14 @@ class TraceIndice(object):
self._assgin_no_change_indice(node, idx)
elif "new_ones" == node_name:
self._assign_all_indice(node, idx)
elif "flatten" == node_name:
self._assign_flatten_indice(node, idx)
elif "expand" == node_name:
self._assign_expand_indice(node, idx)
elif "unbind" == node_name:
self._assign_unbind_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(i == node_name for i in ["size"]):
continue
else:
@@ -859,7 +913,9 @@ class TraceIndice(object):
self._assign_linear_indice(node, idx)
elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu"]):
elif "identity" == node_name:
self._assgin_no_change_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node_name, "module not implemented yet!")
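
To make the rewritten view/reshape handling above easier to follow, here is a standalone re-derivation of its dim-matching idea (a sketch, not the class method itself): dims that are equal are matched greedily from the front and from the back, and whatever remains on each side is what the view merges or expands, giving dim_from and dim_to.

def match_view_dims(origin_shape, target_shape):
    # match identical dims from the front, then from the back
    same_from, same_to = set(), set()
    for i in range(min(len(origin_shape), len(target_shape))):
        if origin_shape[i] == target_shape[i]:
            same_from.add(i)
            same_to.add(i)
        else:
            break
    for i in range(1, min(len(origin_shape), len(target_shape))):
        if origin_shape[-i] == target_shape[-i]:
            same_from.add(len(origin_shape) - i)
            same_to.add(len(target_shape) - i)
        else:
            break
    # the leftover dims are the ones the view actually changes
    dim_from = sorted(set(range(len(origin_shape))) - same_from)
    dim_to = sorted(set(range(len(target_shape))) - same_to)
    return dim_from, dim_to

# (B, H, N, D) -> (B, H*N, D): input dims 1 and 2 merge into output dim 1
print(match_view_dims([2, 12, 577, 64], [2, 6924, 64]))    # ([1, 2], [1])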


@@ -109,8 +109,11 @@ def is_non_compute_node(node: Node) -> bool:
return False
def get_node_shape(node: Node) -> List:
if get_node_name(node) == "split":
def get_node_shape(node: Node) -> Any:
"""
return node data shape
"""
if get_node_name(node) in ["split", "unbind"]:
return node.meta["tensor_meta"][0].shape
if hasattr(node.meta["tensor_meta"], "shape"):
return node.meta["tensor_meta"].shape


@@ -359,7 +359,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.where.self,
aten.zero_.default,
aten.zeros_like.default,
aten.fill_.Scalar
aten.fill_.Scalar,
aten.stack.default
] # yapf: disable
for op in zero_flop_aten:


@@ -0,0 +1,147 @@
import time
from typing import Any, Dict, List
import torch
import torch.fx
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def _benchmark_autochunk_unet_gm(
model: Any,
data: tuple,
max_memory: int = None,
) -> None:
model = model.cuda().eval()
# build model and input
meta_args, concrete_args = data
if concrete_args is None:
concrete_args = {}
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
model = model.cuda().eval()
interp = MetaInfoProp(meta_graph)
meta_tensors = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model.cuda().eval(),
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# init inputs
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
model.cuda().eval()
# bench
para_mem = float(parameter_size(model)) / 1024**2
act_mem = _benchmark_memory(gm, inputs)
speed = _benchmark_speed(gm, inputs)
print("unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
(speed, act_mem, para_mem, act_mem + para_mem))
def _benchmark_autochunk_unet_origin(
model: Any,
data: tuple,
) -> None:
# build model and input
meta_args, concrete_args = data
if concrete_args is None:
concrete_args = {}
# init inputs
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
model.cuda().eval()
# bench
para_mem = float(parameter_size(model)) / 1024**2
act_mem = _benchmark_memory(model, inputs)
speed = _benchmark_speed(model, inputs)
print("unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
(speed, act_mem, para_mem, act_mem + para_mem))
return act_mem
def _benchmark_memory(model, inputs):
with torch.no_grad():
torch.cuda.reset_peak_memory_stats()
now_mem = float(torch.cuda.memory_allocated()) / 1024**2
model(*inputs)
new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
return new_max_mem - now_mem
def _benchmark_speed(model, inputs, loop=5):
with torch.no_grad():
for _ in range(loop // 2 + 1):
model(*inputs)
torch.cuda.synchronize()
time1 = time.time()
for _ in range(loop):
model(*inputs)
torch.cuda.synchronize()
time2 = time.time()
return (time2 - time1) / loop
def benchmark_autochunk_unet(batch=1, height=448, width=448):
from test_autochunk_unet import UNet2DModel, get_data
model = UNet2DModel()
latent_shape = (batch, 3, height // 7, width // 7)
print("\nbatch: %d, height: %d, width: %d" % (batch, height, width))
max_mem = _benchmark_autochunk_unet_origin(model, get_data(latent_shape))
for ratio in [0.5, 0.4, 0.3, 0.2]:
try:
_benchmark_autochunk_unet_gm(model, get_data(latent_shape), max_mem * ratio)
except RuntimeError as e:
if e.args[0] == 'Search failed. Try a larger memory threshold.':
break
except Exception as e:
raise e
_benchmark_autochunk_unet_gm(model, get_data(latent_shape), None)
if __name__ == "__main__":
# launch colossalai
colossalai.launch(
config={},
rank=0,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
benchmark_autochunk_unet(batch=1, height=224 * 2, width=224 * 2)
benchmark_autochunk_unet(batch=1, height=224 * 3, width=224 * 3)
benchmark_autochunk_unet(batch=1, height=224 * 4, width=224 * 4)


@@ -39,7 +39,7 @@ def get_data(shape: tuple) -> Tuple[List, List]:
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [None])
@pytest.mark.parametrize("max_memory", [None, 150, 300])
def test_evoformer_block(model, shape, max_memory):
run_func = partial(
run_test,
@@ -57,7 +57,7 @@ if __name__ == "__main__":
max_memory=None,
model=UNet2DModel,
print_code=False,
print_mem=False,
print_mem=True,
print_est_mem=False,
print_progress=False,
)


@@ -0,0 +1,53 @@
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from timm.models.vision_transformer import vit_large_patch16_384 as vit
MODELS = [vit]
HAS_REPO = True
except:
MODELS = []
HAS_REPO = False
from test_autochunk_vit_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
def get_data() -> Tuple[List, List]:
data = torch.rand(1, 3, 384, 384)
meta_args = {'x': data}
return data, meta_args
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_memory", [None, 32, 40])
def test_autochunk_vit(model, max_memory):
run_func = partial(
run_test,
max_memory=max_memory,
model=model,
data=get_data(),
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data=get_data(),
max_memory=None,
model=vit,
print_code=False,
print_mem=False,
print_est_mem=False,
print_progress=False,
)


@@ -0,0 +1,128 @@
from typing import Any, Dict, List
import torch
import torch.fx
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def assert_codegen_run(
model: Any,
meta_args: Dict,
data: Any,
max_memory: int = None,
print_mem: bool = False,
print_est_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
model = model()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(model, meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()})
model = model.cuda().eval()
interp = MetaInfoProp(meta_graph)
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args.items()]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_est_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model.cuda(),
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert chunk in code
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_size = None; " in code
# assert result
inputs = [data.cuda()]
model.cuda().eval()
gm.eval()
with torch.no_grad():
if print_mem:
torch.cuda.reset_peak_memory_stats()
now_mem_gm = torch.cuda.memory_allocated() / 1024**2
out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
if print_mem:
max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
torch.cuda.reset_peak_memory_stats()
now_mem_ori = torch.cuda.memory_allocated() / 1024**2
out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
if print_mem:
max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))
assert torch.allclose(out_gm, out_model,
atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(out_gm - out_model))
return chunks
def run_test(
rank: int,
model: Any,
data: tuple,
max_memory: int,
print_code: bool = False,
print_mem: bool = False,
print_est_mem: bool = False,
print_progress: bool = False,
get_chunk_target: Any = None,
) -> None:
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
data, meta_args = data
chunks = assert_codegen_run(
model,
meta_args=meta_args,
data=data,
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_est_mem=print_est_mem,
print_progress=print_progress,
)
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]
chunk_target = get_chunk_target()[max_memory]
assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
str(chunk_found),
str(chunk_target),
)
gpc.destroy()