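"""Benchmark the latency, throughput, and memory usage of ColossalAI's InferenceEngine
on randomly initialized LLaMA-family models.

Example invocation (illustrative; the script spawns tp_size * pp_size worker processes
itself, so it is started with a plain ``python`` call using the flags defined by the
argument parser below):

    python <path-to-this-script> -m llama2-7b --tp_size 1 --pp_size 2 -b 8 -s 512 --output_len 128 -v
"""
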
import argparse
import time

import torch
import torch.distributed as dist
import transformers

import colossalai
import colossalai.utils.device as device_utils
from colossalai.inference import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils.device import get_current_device

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

# Model configurations keyed by the --model CLI flag; models are built from these
# configs with random weights, so no checkpoint download is needed.
CONFIG_MAP = {
    "toy": transformers.LlamaConfig(num_hidden_layers=4),
    "llama-7b": transformers.LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=32,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    ),
    "llama-13b": transformers.LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_attention_heads=40,
        num_hidden_layers=40,
        num_key_value_heads=40,
        max_position_embeddings=2048,
    ),
    "llama2-7b": transformers.LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=32,
        num_key_value_heads=32,
        max_position_embeddings=4096,
    ),
    "llama2-13b": transformers.LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_attention_heads=40,
        num_hidden_layers=40,
        num_key_value_heads=40,
        max_position_embeddings=4096,
    ),
}


def data_gen(batch_size: int = 4, seq_len: int = 512):
    # Generate a random batch of token ids with a full attention mask.
    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
    attention_mask = torch.ones_like(input_ids)
    data = dict(input_ids=input_ids, attention_mask=attention_mask)
    return data


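# Reports per-micro-batch timings (when --verbose is set and the engine returns
# timestamps), whole-batch latency and throughput on rank 0, and peak memory
# usage for the local device.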
def print_details_info(outputs, model_config, args, whole_end2end):
    msg: str = ""

    if dist.get_rank() == 0:
        msg += "-------Perf Summary-------\n"
        if args.verbose:
            timestamps = outputs[1]
            prefill = []
            encoder = []
            end2end = []
            for timestamp in timestamps:
                prefill.append(timestamp[1] - timestamp[0])
                encoder.append(
                    sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
                )
                end2end.append(timestamp[-1] - timestamp[0])

            mb_avg_end2end = sum(end2end) / len(end2end)
            mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)

            msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n"
            msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n"
            msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n"
            msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n"

        whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
        num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
        # Approximate parameter count: ~12 * hidden_size^2 per transformer layer,
        # divided across pipeline stages (e.g. llama-7b: 32 * 4096**2 * 12 ≈ 6.4e9).
        num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
        if args.dtype in ["fp16", "bf16"]:
            num_bytes = 2
        else:
            num_bytes = 4

        msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
        msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
        msg += f"Throughput: {args.output_len * args.batch_size / whole_end2end:.2f} tokens/s\n"
        msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"

    if torch.cuda.is_available():
        msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n"
        msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n"
        msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n"

    print(msg)


def benchmark_inference(args):
    config = CONFIG_MAP[args.model]
    model = transformers.LlamaForCausalLM(config)
    if dist.get_rank() == 0:
        print("Model loaded")
    engine = InferenceEngine(
        pp_size=args.pp_size,
        tp_size=args.tp_size,
        dtype=args.dtype,
        micro_batch_size=args.mb_size,
        model=model,
        verbose=args.verbose,
        max_batch_size=args.batch_size,
        max_input_len=args.seq_len,
        max_output_len=args.output_len,
    )
    data = data_gen(args.batch_size, args.seq_len)

    N_WARMUP_STEPS = 2

    # Warm up to exclude one-time allocation and compilation overhead from the measurement.
    for _ in range(N_WARMUP_STEPS):
        engine.generate(data)

    # Time a single end-to-end generation pass.
    torch.cuda.synchronize()
    whole_end2end = time.time()
    outputs = engine.generate(data)
    torch.cuda.synchronize()
    whole_end2end = time.time() - whole_end2end

    print_details_info(outputs, model.config, args, whole_end2end)


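# Entry point for each spawned worker: initialize the ColossalAI distributed
# environment, then run the benchmark on this rank.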
def hybrid_inference(rank, world_size, port, args):
|
||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||
|
benchmark_inference(args)
|
||
|
|
||
|
|
||
|
@rerun_if_address_is_in_use()
|
||
|
@clear_cache_before_run()
|
||
|
def benchmark(args):
|
||
|
spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        default="toy",
        help="the size of the model",
        choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
    )
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
    parser.add_argument("--mb_size", type=int, default=1, help="micro batch size")
    parser.add_argument("--pp_size", type=int, default=1, help="pipeline parallel size")
    parser.add_argument("--tp_size", type=int, default=1, help="tensor parallel size")
    parser.add_argument("--output_len", type=int, default=128, help="output length")
    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
    parser.add_argument("-v", "--verbose", default=False, action="store_true")
    args = parser.parse_args()
    benchmark(args)