mirror of https://github.com/hpcaitech/ColossalAI
oahzxl
2 years ago
committed by
GitHub
10 changed files with 286 additions and 5 deletions
@ -0,0 +1,131 @@ |
|||||||
|
import time |
||||||
|
from typing import Any, Dict, List |
||||||
|
|
||||||
|
import torch |
||||||
|
import torch.fx |
||||||
|
|
||||||
|
import colossalai |
||||||
|
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE |
||||||
|
from colossalai.fx.graph_module import ColoGraphModule |
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp |
||||||
|
from colossalai.utils import free_port |
||||||
|
|
||||||
|
if AUTOCHUNK_AVAILABLE: |
||||||
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen |
||||||
|
from colossalai.fx.profiler import MetaTensor |
||||||
|
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace |
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_evoformer_stack_gm( |
||||||
|
data_args: tuple, |
||||||
|
max_memory: int, |
||||||
|
get_model: Any, |
||||||
|
get_data: Any, |
||||||
|
) -> None: |
||||||
|
# build model and input |
||||||
|
model = get_model() |
||||||
|
meta_args, concrete_args = get_data(*data_args) |
||||||
|
if concrete_args is None: |
||||||
|
concrete_args = [] |
||||||
|
|
||||||
|
# trace the meta graph and setup codegen |
||||||
|
meta_graph = symbolic_trace( |
||||||
|
model, |
||||||
|
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, |
||||||
|
concrete_args={k: v for k, v in concrete_args}, |
||||||
|
) |
||||||
|
interp = MetaInfoProp(meta_graph) |
||||||
|
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args] |
||||||
|
interp.propagate(*meta_tensors) |
||||||
|
codegen = AutoChunkCodeGen( |
||||||
|
meta_graph, |
||||||
|
max_memory=max_memory, |
||||||
|
) |
||||||
|
|
||||||
|
# trace and recompile |
||||||
|
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer |
||||||
|
graph = ColoTracer().trace( |
||||||
|
model, |
||||||
|
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args}, |
||||||
|
concrete_args={k: v for k, v in concrete_args}, |
||||||
|
) |
||||||
|
graph.set_codegen(codegen) |
||||||
|
gm = ColoGraphModule(model, graph, ckpt_codegen=False) |
||||||
|
gm.recompile() |
||||||
|
|
||||||
|
# init inputs |
||||||
|
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] |
||||||
|
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] |
||||||
|
model.cuda() |
||||||
|
|
||||||
|
# bench |
||||||
|
mem = _benchmark_memory(gm, inputs) |
||||||
|
speed = _benchmark_speed(gm, inputs) |
||||||
|
print("evoformer stack gm, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args))) |
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_evoformer_stack_origin( |
||||||
|
data_args: tuple, |
||||||
|
get_model: Any, |
||||||
|
get_data: Any, |
||||||
|
) -> None: |
||||||
|
# build model and input |
||||||
|
model = get_model() |
||||||
|
meta_args, concrete_args = get_data(*data_args) |
||||||
|
if concrete_args is None: |
||||||
|
concrete_args = [] |
||||||
|
|
||||||
|
# init inputs |
||||||
|
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args] |
||||||
|
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] |
||||||
|
model.cuda() |
||||||
|
|
||||||
|
# bench |
||||||
|
mem = _benchmark_memory(model, inputs) |
||||||
|
speed = _benchmark_speed(model, inputs) |
||||||
|
print("evoformer stack origin, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args))) |
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_memory(model, inputs): |
||||||
|
with torch.no_grad(): |
||||||
|
torch.cuda.reset_peak_memory_stats() |
||||||
|
now_mem = torch.cuda.memory_allocated() / 1024**2 |
||||||
|
model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) |
||||||
|
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 |
||||||
|
return new_max_mem - now_mem |
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_speed(model, inputs, loop=5): |
||||||
|
with torch.no_grad(): |
||||||
|
for _ in range(loop // 2 + 1): |
||||||
|
model(*inputs) |
||||||
|
torch.cuda.synchronize() |
||||||
|
time1 = time.time() |
||||||
|
for _ in range(loop): |
||||||
|
model(*inputs) |
||||||
|
torch.cuda.synchronize() |
||||||
|
time2 = time.time() |
||||||
|
return (time2 - time1) / loop |
||||||
|
|
||||||
|
|
||||||
|
def benchmark_evoformer_stack(): |
||||||
|
from test_autochunk_evoformer_stack import get_data, get_model |
||||||
|
data_args = [128, 256] |
||||||
|
print("") |
||||||
|
_benchmark_evoformer_stack_origin(data_args, get_model, get_data) |
||||||
|
_benchmark_evoformer_stack_gm(data_args, 600, get_model, get_data) |
||||||
|
_benchmark_evoformer_stack_gm(data_args, 400, get_model, get_data) |
||||||
|
_benchmark_evoformer_stack_gm(data_args, None, get_model, get_data) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
# launch colossalai |
||||||
|
colossalai.launch( |
||||||
|
config={}, |
||||||
|
rank=0, |
||||||
|
world_size=1, |
||||||
|
host="localhost", |
||||||
|
port=free_port(), |
||||||
|
backend="nccl", |
||||||
|
) |
||||||
|
benchmark_evoformer_stack() |
@ -0,0 +1,150 @@ |
|||||||
|
import time |
||||||
|
from typing import Any, Dict, List |
||||||
|
|
||||||
|
import torch |
||||||
|
import torch.fx |
||||||
|
|
||||||
|
import colossalai |
||||||
|
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE |
||||||
|
from colossalai.fx.graph_module import ColoGraphModule |
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp |
||||||
|
from colossalai.fx.profiler import parameter_size |
||||||
|
from colossalai.utils import free_port |
||||||
|
|
||||||
|
if AUTOCHUNK_AVAILABLE: |
||||||
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen |
||||||
|
from colossalai.fx.profiler import MetaTensor |
||||||
|
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace |
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_autochunk_gpt_gm( |
||||||
|
model: Any, |
||||||
|
data: tuple, |
||||||
|
max_memory: int = None, |
||||||
|
) -> None: |
||||||
|
model = model.cuda().eval() |
||||||
|
|
||||||
|
# build model and input |
||||||
|
meta_args, concrete_args, sequence = data |
||||||
|
if concrete_args is None: |
||||||
|
concrete_args = {} |
||||||
|
|
||||||
|
# trace the meta graph and setup codegen |
||||||
|
meta_graph = symbolic_trace( |
||||||
|
model, |
||||||
|
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, |
||||||
|
concrete_args={k: v for k, v in concrete_args.items()}, |
||||||
|
) |
||||||
|
interp = MetaInfoProp(meta_graph) |
||||||
|
meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] |
||||||
|
meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors] |
||||||
|
interp.propagate(*meta_tensors) |
||||||
|
codegen = AutoChunkCodeGen( |
||||||
|
meta_graph, |
||||||
|
max_memory=max_memory, |
||||||
|
) |
||||||
|
|
||||||
|
# trace and recompile |
||||||
|
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer |
||||||
|
graph = ColoTracer().trace( |
||||||
|
model.cuda().eval(), |
||||||
|
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()}, |
||||||
|
concrete_args={k: v for k, v in concrete_args.items()}, |
||||||
|
) |
||||||
|
graph.set_codegen(codegen) |
||||||
|
gm = ColoGraphModule(model, graph, ckpt_codegen=False) |
||||||
|
gm.recompile() |
||||||
|
|
||||||
|
# init inputs |
||||||
|
inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] |
||||||
|
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] |
||||||
|
model.cuda().eval() |
||||||
|
|
||||||
|
# bench |
||||||
|
para_mem = float(parameter_size(model)) / 1024**2 * 6 |
||||||
|
act_mem = _benchmark_memory(gm, inputs) |
||||||
|
speed = _benchmark_speed(gm, inputs) |
||||||
|
print("gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % |
||||||
|
(speed, act_mem, para_mem, act_mem + para_mem)) |
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_autochunk_gpt_origin( |
||||||
|
model: Any, |
||||||
|
data: tuple, |
||||||
|
) -> None: |
||||||
|
# build model and input |
||||||
|
meta_args, concrete_args, sequence = data |
||||||
|
if concrete_args is None: |
||||||
|
concrete_args = {} |
||||||
|
|
||||||
|
# init inputs |
||||||
|
inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] |
||||||
|
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs] |
||||||
|
model.cuda().eval() |
||||||
|
|
||||||
|
# bench |
||||||
|
para_mem = float(parameter_size(model)) / 1024**2 * 6 |
||||||
|
act_mem = _benchmark_memory(model, inputs) |
||||||
|
speed = _benchmark_speed(model, inputs) |
||||||
|
print("gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % |
||||||
|
(speed, act_mem, para_mem, act_mem + para_mem)) |
||||||
|
return act_mem |
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_memory(model, inputs): |
||||||
|
with torch.no_grad(): |
||||||
|
torch.cuda.reset_peak_memory_stats() |
||||||
|
now_mem = float(torch.cuda.memory_allocated()) / 1024**2 |
||||||
|
model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) |
||||||
|
new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2 |
||||||
|
return new_max_mem - now_mem |
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_speed(model, inputs, loop=5): |
||||||
|
with torch.no_grad(): |
||||||
|
for _ in range(loop // 2 + 1): |
||||||
|
model(*inputs) |
||||||
|
torch.cuda.synchronize() |
||||||
|
time1 = time.time() |
||||||
|
for _ in range(loop): |
||||||
|
model(*inputs) |
||||||
|
torch.cuda.synchronize() |
||||||
|
time2 = time.time() |
||||||
|
return (time2 - time1) / loop |
||||||
|
|
||||||
|
|
||||||
|
def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): |
||||||
|
from test_autochunk_gpt import GPT2Config, GPT2Model, get_data |
||||||
|
model = GPT2Model |
||||||
|
config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=n_head) |
||||||
|
config.max_position_embeddings = seq |
||||||
|
model = model(config=config) |
||||||
|
shape = [batch, seq] |
||||||
|
print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head)) |
||||||
|
max_mem = _benchmark_autochunk_gpt_origin(model, get_data(shape)) |
||||||
|
for ratio in [0.5, 0.4, 0.3, 0.2]: |
||||||
|
try: |
||||||
|
_benchmark_autochunk_gpt_gm(model, get_data(shape), max_mem * ratio) |
||||||
|
except RuntimeError as e: |
||||||
|
if e.args[0] == 'Search failed. Try a larger memory threshold.': |
||||||
|
break |
||||||
|
except Exception as e: |
||||||
|
raise e |
||||||
|
_benchmark_autochunk_gpt_gm(model, get_data(shape), None) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
# launch colossalai |
||||||
|
colossalai.launch( |
||||||
|
config={}, |
||||||
|
rank=0, |
||||||
|
world_size=1, |
||||||
|
host="localhost", |
||||||
|
port=free_port(), |
||||||
|
backend="nccl", |
||||||
|
) |
||||||
|
benchmark_autochunk_gpt(batch=1, seq=1024, n_embd=768, n_head=12) |
||||||
|
benchmark_autochunk_gpt(batch=1, seq=2048, n_embd=768, n_head=12) |
||||||
|
benchmark_autochunk_gpt(batch=1, seq=4096, n_embd=768, n_head=12) |
||||||
|
benchmark_autochunk_gpt(batch=1, seq=6144, n_embd=768, n_head=12) |
||||||
|
benchmark_autochunk_gpt(batch=1, seq=8192, n_embd=768, n_head=12) |
Loading…
Reference in new issue