[Inference]Support vllm testing in benchmark scripts (#5379)

* add vllm benchmark scripts

* fix code style

* update run_benchmark.sh

* fix code style
yuehuayingxueluo 2024-02-08 15:27:26 +08:00 committed by GitHub
parent 9afa52061f
commit 8c69debdc7
3 changed files with 69 additions and 19 deletions


@@ -139,6 +139,7 @@ class InferenceEngine:
         self,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
+        return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
         """
@@ -147,6 +148,7 @@ class InferenceEngine:
         Args:
             prompts (Union[List[str], optional): Input prompts. Defaults to None.
             prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+            return_token_ids (bool): Whether to return output token ids. Defaults to False.
             generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.

         Returns:
@@ -158,7 +160,7 @@ class InferenceEngine:
         self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

         output_seqs_list = []
-        output_tokens_list = []
+        total_tokens_list = []

         # intuition: If user provide a generation config, we should replace the existing one.
         if generation_config is not None:
@@ -170,10 +172,14 @@ class InferenceEngine:
         output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

         for seq in output_seqs_list:
-            output_tokens_list.append(seq.input_token_id + seq.output_token_id)
+            total_tokens_list.append(seq.input_token_id + seq.output_token_id)

-        output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
+        output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)

-        return output_str
+        if return_token_ids:
+            output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
+            return output_str, output_tokens_list
+        else:
+            return output_str

     @property
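With this change, a caller of generate() can receive the generated token ids alongside the decoded strings. A minimal usage sketch, assuming an already constructed InferenceEngine (engine) and a hypothetical list of tokenized prompts prompt_ids:

    # Sketch only: engine is an initialized InferenceEngine,
    # prompt_ids is a hypothetical List[List[int]] of tokenized prompts.
    output_str, output_tokens_list = engine.generate(
        prompts_token_ids=prompt_ids,
        return_token_ids=True,
    )
    # output_tokens_list holds only the newly generated ids per sequence,
    # so produced tokens can be counted directly.
    num_generated = sum(len(ids) for ids in output_tokens_list)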


@@ -6,6 +6,7 @@ import torch
 import torch.distributed as dist
 import transformers
 from transformers import AutoTokenizer, GenerationConfig
+from vllm import LLM, SamplingParams

 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -58,12 +59,12 @@ def data_gen(batch_size: int = 4, seq_len: int = 512):
     return input_ids


-def print_details_info(model_config, args, whole_end2end):
+def print_details_info(model_config, args, whole_end2end, total_token_num):
     msg: str = ""

     if dist.get_rank() == 0:
         msg += "-------Perf Summary-------\n"
-        whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
+        whole_avg_latency = whole_end2end / (total_token_num)
         num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
         num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
         if args.dtype in ["fp16", "bf16"]:
@@ -73,7 +74,7 @@ def print_details_info(model_config, args, whole_end2end):
         msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
         msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
-        msg += f"Throughput: {args.output_len * args.batch_size / whole_end2end:.2f} tokens/s\n"
+        msg += f"Throughput: {total_token_num / whole_end2end:.2f} tokens/s\n"
         msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"

     if torch.cuda.is_available():
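The revised summary divides by the measured token count rather than the nominal args.output_len * args.batch_size. A small illustrative calculation with made-up numbers, mirroring the formulas above:

    # Hypothetical numbers; the formulas match print_details_info above.
    whole_end2end = 4.0      # seconds for the whole batch
    total_token_num = 2048   # tokens actually generated (counted, not assumed)

    whole_avg_latency = whole_end2end / total_token_num   # ~0.00195 s, i.e. ~1.95 ms per token
    throughput = total_token_num / whole_end2end          # 512.00 tokens/s
    print(f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms")
    print(f"Throughput: {throughput:.2f} tokens/s")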
@@ -88,9 +89,15 @@ def benchmark_inference(args):
     with torch.no_grad():
         config = CONFIG_MAP[args.model]
         config.pad_token_id = config.eos_token_id
-        model = transformers.LlamaForCausalLM(config).cuda()
-        model = model.eval()
-        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+
+        if args.test_random_weight:
+            model = transformers.LlamaForCausalLM(config).cuda()
+            tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+        else:
+            assert args.model_path, "When testing pretrained weights, the model path must be provided.'"
+            model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda()
+            tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+
+        model = model.eval()

         if args.dtype == "fp16":
             model = model.half()
@@ -101,7 +108,7 @@ def benchmark_inference(args):
             mbsz = args.mbsz
         else:
            mbsz = args.batch_size
-        if args.mode == "caiinference":
+        if args.mode == "colossalai":
             inference_config = InferenceConfig(
                 dtype=args.dtype,
                 micro_batch_size=args.mb_size,
@@ -109,12 +116,27 @@ def benchmark_inference(args):
                 max_input_len=args.seq_len,
                 max_output_len=args.output_len,
                 prefill_ratio=1.2,
+                block_size=32,
             )
             engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+        elif args.mode == "vllm":
+            engine = LLM(
+                model=args.model_path,
+                max_num_seqs=mbsz,
+                dtype="float16",
+                enforce_eager=True,
+            )
+
+            sampling_params = SamplingParams(
+                max_tokens=args.output_len,
+            )
         else:
             engine = model

         data = data_gen(mbsz, args.seq_len)
+        data = data.tolist()

         generation_config = GenerationConfig(
             pad_token_id=tokenizer.pad_token_id,
             max_new_tokens=args.output_len,
@@ -132,7 +154,7 @@ def benchmark_inference(args):
                     torch.profiler.ProfilerActivity.CUDA,
                 ],
                 schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
-                on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log_" + args.mode),
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode),
             )
             if args.profile
             else nullcontext()
@@ -140,8 +162,10 @@ def benchmark_inference(args):
         with ctx:
             for _ in range(N_WARMUP_STEPS):
-                if args.mode == "caiinference":
+                if args.mode == "colossalai":
                     engine.generate(prompts_token_ids=data, generation_config=generation_config)
+                elif args.mode == "vllm":
+                    engine.generate(prompt_token_ids=data, sampling_params=sampling_params)
                 else:
                     engine.generate(data, generation_config=generation_config)

             if args.profile:
@@ -153,19 +177,35 @@ def benchmark_inference(args):
             torch.cuda.synchronize()
             whole_end2end = time.perf_counter()
-            if args.mode == "caiinference":
+            if args.mode == "colossalai":
                 for _ in range(args.batch_size // mbsz):
-                    engine.generate(prompts_token_ids=data, generation_config=generation_config)
+                    output, output_tokens_list = engine.generate(
+                        prompts_token_ids=data, generation_config=generation_config, return_token_ids=True
+                    )
+            elif args.mode == "vllm":
+                for _ in range(args.batch_size // mbsz):
+                    output = engine.generate(prompt_token_ids=data, sampling_params=sampling_params)
             else:
                 for _ in range(args.batch_size // mbsz):
-                    engine.generate(data, generation_config=generation_config)
+                    output = engine.generate(data, generation_config=generation_config)

             whole_end2end = time.perf_counter() - whole_end2end
+
+            if args.mode == "colossalai":
+                total_token_num = sum([len(output_tokens) for output_tokens in output_tokens_list])
+            elif args.mode == "vllm":
+                total_token_num = sum([len(out.outputs[0].token_ids) for out in output])
+            else:
+                total_token_num = sum([len(out) for out in output])
+
+            print("total_token_num: ", total_token_num)
             if args.nsys:
                 torch.cuda.cudart().cudaProfilerStop()
             if args.profile:
                 ctx.step()
-    print_details_info(model.config, args, whole_end2end)
+    print_details_info(model.config, args, whole_end2end, total_token_num)


 def hybrid_inference(rank, world_size, port, args):
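The vllm mode builds an offline LLM engine, feeds it pre-tokenized prompts, and counts the token ids of each request's first completion. A condensed sketch of that path, assuming vLLM's offline API as used in the diff; the model path and prompt ids below are placeholders:

    from vllm import LLM, SamplingParams

    # Placeholders: point model= at a real checkpoint directory and supply real token ids.
    engine = LLM(model="/path/to/llama2-7b", max_num_seqs=8, dtype="float16", enforce_eager=True)
    sampling_params = SamplingParams(max_tokens=128)

    prompt_ids = [[1, 15043, 3186], [1, 6324, 727]]  # pre-tokenized prompts
    outputs = engine.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)

    # Each RequestOutput exposes the generated ids of its first completion,
    # which is what the benchmark sums into total_token_num.
    total_token_num = sum(len(out.outputs[0].token_ids) for out in outputs)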
@@ -188,6 +228,7 @@ if __name__ == "__main__":
         help="the size of model",
         choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
     )
+    parser.add_argument("--model_path", type=str, default=None, help="The pretrained weights path")
     parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
     parser.add_argument("--mbsz", type=int, default=8, help="batch size for one step")
     parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length")
@@ -197,12 +238,15 @@ if __name__ == "__main__":
     parser.add_argument("--output_len", type=int, default=128, help="Output length")
     parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"])
     parser.add_argument("-v", "--verbose", default=False, action="store_true")
+    parser.add_argument(
+        "--test_random_weight", default=False, action="store_true", help="whether to test random weight"
+    )
     parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
     parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")
     parser.add_argument(
         "--mode",
-        default="caiinference",
-        choices=["caiinference", "transformers"],
+        default="colossalai",
+        choices=["colossalai", "transformers", "vllm"],
         help="decide which inference framework to run",
     )
     parser.add_argument(


@@ -26,7 +26,7 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
 for input_len in 128 512 1024; do
     for output_len in 128 256; do
         for bsz in 16 32 64; do
-            python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
+            python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${input_len}_${output_len}_${mode}_${GPU}_${bsz}.txt
         done
     done
 done
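To benchmark pretrained weights (for example with the new vllm mode) rather than random weights, the --test_random_weight flag is dropped and a checkpoint is passed explicitly. A hypothetical invocation with a placeholder path, reusing the script's own variables:

    # Placeholder checkpoint path; ${PY_SCRIPT} and ${GPU} as defined in the script above.
    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b 16 -s 512 --output_len 128 \
        --mode vllm --model_path /path/to/llama2-7b-hf | tee logs/512_128_vllm_${GPU}_16.txt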