mirror of https://github.com/hpcaitech/ColossalAI
[inference] refactor examples and fix schedule (#5077)
* [setup] refactor infer setup * [hotfix] fix infenrece behavior on 1 1 gpu * [exmaple] refactor inference examplespull/5087/head
parent
4e3959d316
commit
1cd7efc520
|
@ -33,13 +33,16 @@ class InferenceEngine:
|
||||||
Args:
|
Args:
|
||||||
tp_size (int): the size of tensor parallelism.
|
tp_size (int): the size of tensor parallelism.
|
||||||
pp_size (int): the size of pipeline parallelism.
|
pp_size (int): the size of pipeline parallelism.
|
||||||
|
dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'.
|
||||||
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
|
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
|
||||||
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
|
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. It will be determined by the model type if not provided.
|
||||||
micro_batch_size (int): the micro batch size.
|
micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
|
||||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||||
max_batch_size (int): the maximum batch size.
|
max_batch_size (int): the maximum batch size.
|
||||||
max_input_len (int): the maximum input length.
|
max_input_len (int): the maximum input length.
|
||||||
max_output_len (int): the maximum output length.
|
max_output_len (int): the maximum output length.
|
||||||
|
quant (str): the quantization method, should be one of 'smoothquant', 'gptq', None.
|
||||||
|
verbose (bool): whether to return the time cost of each step.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -69,6 +69,8 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
batch = tree_map(partial(to_device, device=device), batch)
|
batch = tree_map(partial(to_device, device=device), batch)
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
self.batch_size = get_batch_size(batch)
|
self.batch_size = get_batch_size(batch)
|
||||||
|
if self.stage_manager.num_stages == 1:
|
||||||
|
self.microbatch_size = self.batch_size
|
||||||
self.microbatch_offset = 0
|
self.microbatch_offset = 0
|
||||||
assert (
|
assert (
|
||||||
self.batch_size % self.microbatch_size == 0
|
self.batch_size % self.microbatch_size == 0
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
def print_perf_stats(latency_set, config, bs, warmup=3):
|
|
||||||
# trim warmup queries
|
|
||||||
latency_set = list(latency_set)
|
|
||||||
latency_set = latency_set[warmup:]
|
|
||||||
count = len(latency_set)
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
latency_set.sort()
|
|
||||||
avg = sum(latency_set) / count
|
|
||||||
num_layers = (
|
|
||||||
getattr(config, "num_layers") if hasattr(config, "num_layers") else getattr(config, "num_hidden_layers")
|
|
||||||
)
|
|
||||||
num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
|
|
||||||
num_bytes = 2 # float16
|
|
||||||
|
|
||||||
print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
|
|
||||||
print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
|
|
||||||
print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
|
|
||||||
print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
|
|
|
@ -1,167 +0,0 @@
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.inference import CaiInferEngine
|
|
||||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
|
||||||
|
|
||||||
GIGABYTE = 1024**3
|
|
||||||
MEGABYTE = 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def data_gen(batch_size: int = 4, seq_len: int = 512):
|
|
||||||
input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
|
|
||||||
attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
|
|
||||||
data = dict(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
for k, v in data.items():
|
|
||||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
|
||||||
new_shape = [1] * v.dim()
|
|
||||||
new_shape[0] = batch_size
|
|
||||||
data[k] = v.to("cuda").repeat(*new_shape)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def print_details_info(timestamps, model_config, args, whole_end2end):
|
|
||||||
log_file_name = f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.output_len}_bsz{args.batch_size}_mbsz{args.mb_size}.log"
|
|
||||||
os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
|
|
||||||
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
prefill = []
|
|
||||||
encoder = []
|
|
||||||
end2end = []
|
|
||||||
for timestamp in timestamps:
|
|
||||||
prefill.append(timestamp[1] - timestamp[0])
|
|
||||||
encoder.append(
|
|
||||||
sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
|
|
||||||
)
|
|
||||||
end2end.append(timestamp[-1] - timestamp[0])
|
|
||||||
print(whole_end2end)
|
|
||||||
|
|
||||||
with open(
|
|
||||||
log_file_name,
|
|
||||||
"w+",
|
|
||||||
) as f:
|
|
||||||
mb_avg_end2end = sum(end2end) / len(end2end)
|
|
||||||
mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)
|
|
||||||
whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
|
|
||||||
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
|
|
||||||
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
|
|
||||||
if args.dtype in ["fp16", "bf16"]:
|
|
||||||
num_bytes = 2
|
|
||||||
else:
|
|
||||||
num_bytes = 4
|
|
||||||
|
|
||||||
f.write(
|
|
||||||
f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.output_len}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
|
|
||||||
)
|
|
||||||
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
|
|
||||||
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
|
|
||||||
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
|
|
||||||
f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
|
|
||||||
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
|
|
||||||
f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
|
|
||||||
f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
|
|
||||||
f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
|
|
||||||
f.write("----------------------------------------------------------\n")
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
current_device = torch.cuda.current_device()
|
|
||||||
|
|
||||||
# free memory and the total available memory in bytes
|
|
||||||
global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info()
|
|
||||||
memory_allocated = torch.cuda.memory_allocated()
|
|
||||||
max_memory_allocated = torch.cuda.max_memory_allocated()
|
|
||||||
memory_reserved = torch.cuda.memory_reserved()
|
|
||||||
max_memory_reserved = torch.cuda.max_memory_reserved()
|
|
||||||
with open(
|
|
||||||
log_file_name,
|
|
||||||
"a",
|
|
||||||
) as f:
|
|
||||||
f.write(
|
|
||||||
f"\nCurrently using GPU: {current_device}\n"
|
|
||||||
f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
|
|
||||||
f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n"
|
|
||||||
f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
|
|
||||||
f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
|
|
||||||
f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
|
|
||||||
f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_inference(args):
|
|
||||||
if args.model == "toy":
|
|
||||||
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))
|
|
||||||
elif args.model == "7b":
|
|
||||||
model = transformers.LlamaForCausalLM(
|
|
||||||
transformers.LlamaConfig(
|
|
||||||
hidden_size=4096,
|
|
||||||
intermediate_size=11008,
|
|
||||||
num_attention_heads=32,
|
|
||||||
num_hidden_layers=32,
|
|
||||||
num_key_value_heads=32,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif args.model == "13b":
|
|
||||||
model = transformers.LlamaForCausalLM(
|
|
||||||
transformers.LlamaConfig(
|
|
||||||
hidden_size=5120,
|
|
||||||
intermediate_size=13824,
|
|
||||||
num_attention_heads=40,
|
|
||||||
num_hidden_layers=40,
|
|
||||||
num_key_value_heads=40,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
engine = CaiInferEngine(
|
|
||||||
pp_size=args.pp_size,
|
|
||||||
tp_size=args.tp_size,
|
|
||||||
dtype=args.dtype,
|
|
||||||
micro_batch_size=args.mb_size,
|
|
||||||
model=model,
|
|
||||||
verbose=True,
|
|
||||||
max_batch_size=args.mb_size,
|
|
||||||
max_input_len=args.seq_len,
|
|
||||||
max_output_len=args.output_len,
|
|
||||||
)
|
|
||||||
data = data_gen(args.batch_size, args.seq_len)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
whole_end2end = time.time()
|
|
||||||
output, timestamps = engine.generate(data)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
whole_end2end = time.time() - whole_end2end
|
|
||||||
|
|
||||||
print_details_info(timestamps, model.config, args, whole_end2end)
|
|
||||||
|
|
||||||
|
|
||||||
def hybrid_inference(rank, world_size, port, args):
|
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
|
||||||
benchmark_inference(args)
|
|
||||||
|
|
||||||
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
@clear_cache_before_run()
|
|
||||||
def benchmark(args):
|
|
||||||
spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--model", default="toy", help="the size of model")
|
|
||||||
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
|
|
||||||
parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
|
|
||||||
parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
|
|
||||||
parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
|
|
||||||
parser.add_argument("--tp_size", type=int, default=2, help="pipeline size")
|
|
||||||
parser.add_argument("--output_len", type=int, default=16, help="Output length")
|
|
||||||
parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
|
|
||||||
parser.add_argument("--dtype", type=str, default="fp16", help="data type")
|
|
||||||
args = parser.parse_args()
|
|
||||||
benchmark(args)
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
import colossalai.utils.device as device_utils
|
||||||
|
from colossalai.inference import InferenceEngine
|
||||||
|
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||||
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
|
GIGABYTE = 1024**3
|
||||||
|
MEGABYTE = 1024 * 1024
|
||||||
|
|
||||||
|
CONFIG_MAP = {
|
||||||
|
"toy": transformers.LlamaConfig(num_hidden_layers=4),
|
||||||
|
"llama-7b": transformers.LlamaConfig(
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_key_value_heads=32,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
),
|
||||||
|
"llama-13b": transformers.LlamaConfig(
|
||||||
|
hidden_size=5120,
|
||||||
|
intermediate_size=13824,
|
||||||
|
num_attention_heads=40,
|
||||||
|
num_hidden_layers=40,
|
||||||
|
num_key_value_heads=40,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
),
|
||||||
|
"llama2-7b": transformers.LlamaConfig(
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_key_value_heads=32,
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
),
|
||||||
|
"llama2-13b": transformers.LlamaConfig(
|
||||||
|
hidden_size=5120,
|
||||||
|
intermediate_size=13824,
|
||||||
|
num_attention_heads=40,
|
||||||
|
num_hidden_layers=40,
|
||||||
|
num_key_value_heads=40,
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def data_gen(batch_size: int = 4, seq_len: int = 512):
|
||||||
|
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
data = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def print_details_info(outputs, model_config, args, whole_end2end):
|
||||||
|
msg: str = ""
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
msg += "-------Perf Summary-------\n"
|
||||||
|
if args.verbose:
|
||||||
|
timestamps = outputs[1]
|
||||||
|
prefill = []
|
||||||
|
encoder = []
|
||||||
|
end2end = []
|
||||||
|
for timestamp in timestamps:
|
||||||
|
prefill.append(timestamp[1] - timestamp[0])
|
||||||
|
encoder.append(
|
||||||
|
sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
|
||||||
|
)
|
||||||
|
end2end.append(timestamp[-1] - timestamp[0])
|
||||||
|
|
||||||
|
mb_avg_end2end = sum(end2end) / len(end2end)
|
||||||
|
mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)
|
||||||
|
|
||||||
|
msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n"
|
||||||
|
msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n"
|
||||||
|
msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n"
|
||||||
|
msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n"
|
||||||
|
|
||||||
|
whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
|
||||||
|
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
|
||||||
|
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
|
||||||
|
if args.dtype in ["fp16", "bf16"]:
|
||||||
|
num_bytes = 2
|
||||||
|
else:
|
||||||
|
num_bytes = 4
|
||||||
|
|
||||||
|
msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
|
||||||
|
msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
|
||||||
|
msg += f"Throughput: {args.output_len * args.batch_size / whole_end2end:.2f} tokens/s\n"
|
||||||
|
msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n"
|
||||||
|
msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n"
|
||||||
|
msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n"
|
||||||
|
|
||||||
|
print(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_inference(args):
|
||||||
|
config = CONFIG_MAP[args.model]
|
||||||
|
model = transformers.LlamaForCausalLM(config)
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
print("Model loaded")
|
||||||
|
engine = InferenceEngine(
|
||||||
|
pp_size=args.pp_size,
|
||||||
|
tp_size=args.tp_size,
|
||||||
|
dtype=args.dtype,
|
||||||
|
micro_batch_size=args.mb_size,
|
||||||
|
model=model,
|
||||||
|
verbose=args.verbose,
|
||||||
|
max_batch_size=args.batch_size,
|
||||||
|
max_input_len=args.seq_len,
|
||||||
|
max_output_len=args.output_len,
|
||||||
|
)
|
||||||
|
data = data_gen(args.batch_size, args.seq_len)
|
||||||
|
|
||||||
|
N_WARMUP_STEPS = 2
|
||||||
|
|
||||||
|
for _ in range(N_WARMUP_STEPS):
|
||||||
|
engine.generate(data)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
whole_end2end = time.time()
|
||||||
|
outputs = engine.generate(data)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
whole_end2end = time.time() - whole_end2end
|
||||||
|
|
||||||
|
print_details_info(outputs, model.config, args, whole_end2end)
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_inference(rank, world_size, port, args):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
|
benchmark_inference(args)
|
||||||
|
|
||||||
|
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
@clear_cache_before_run()
|
||||||
|
def benchmark(args):
|
||||||
|
spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"-m",
|
||||||
|
"--model",
|
||||||
|
default="toy",
|
||||||
|
help="the size of model",
|
||||||
|
choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
|
||||||
|
)
|
||||||
|
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
|
||||||
|
parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
|
||||||
|
parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
|
||||||
|
parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
|
||||||
|
parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
|
||||||
|
parser.add_argument("--output_len", type=int, default=128, help="Output length")
|
||||||
|
parser.add_argument("--dtype", type=str, default="fp16", help="data type")
|
||||||
|
parser.add_argument("-v", "--verbose", default=False, action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
benchmark(args)
|
|
@ -54,15 +54,6 @@ def main():
|
||||||
|
|
||||||
model.save_quantized(output_path, model_basename="llama-7b")
|
model.save_quantized(output_path, model_basename="llama-7b")
|
||||||
|
|
||||||
model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
|
|
||||||
input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
|
|
||||||
out = model.generate(**input_tokens, **generate_kwargs)
|
|
||||||
text = tokenizer.batch_decode(out)
|
|
||||||
print("out is:", text)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1,65 +1,15 @@
|
||||||
script_dir=$(cd "$(dirname "$0")" && pwd)
|
ROOT=$(realpath $(dirname $0))
|
||||||
cd "${script_dir}"
|
PY_SCRIPT=${ROOT}/benchmark_llama.py
|
||||||
|
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
|
||||||
|
|
||||||
# toy model, 2tp*2pp 1024, 128
|
mkdir -p logs
|
||||||
python ./benchmark.py \
|
|
||||||
--model="toy" \
|
|
||||||
--dtype="fp16" \
|
|
||||||
--batch_size=2 \
|
|
||||||
--seq_len=1024 \
|
|
||||||
--output_len=128 \
|
|
||||||
--mb_size=1 \
|
|
||||||
--pp_size=2 \
|
|
||||||
--tp_size=2
|
|
||||||
|
|
||||||
# 7b, fp16, 2 gpu, 1024, 128
|
# benchmark llama2-7b one single GPU
|
||||||
for BATCH_SIZE in 2 4 8 16; do
|
for bsz in 16 32 64; do
|
||||||
python ./benchmark.py \
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 | tee logs/${GPU}_${bsz}_256.txt
|
||||||
--model="7b" \
|
|
||||||
--dtype="fp16" \
|
|
||||||
--batch_size=${BATCH_SIZE} \
|
|
||||||
--seq_len=1024 \
|
|
||||||
--output_len=128 \
|
|
||||||
--mb_size=$((${BATCH_SIZE}/2)) \
|
|
||||||
--pp_size=2 \
|
|
||||||
--tp_size=2
|
|
||||||
done
|
done
|
||||||
|
|
||||||
# 7b, fp16, 2 gpu, 512, 512
|
|
||||||
for BATCH_SIZE in 2 4 8 16 32; do
|
|
||||||
python ./benchmark.py \
|
|
||||||
--model="7b" \
|
|
||||||
--dtype="fp16" \
|
|
||||||
--batch_size=${BATCH_SIZE} \
|
|
||||||
--seq_len=512 \
|
|
||||||
--output_len=512 \
|
|
||||||
--mb_size=$((${BATCH_SIZE}/2)) \
|
|
||||||
--pp_size=2 \
|
|
||||||
--tp_size=2
|
|
||||||
done
|
|
||||||
|
|
||||||
# 7b, fp16, 2 gpu, 1024, 128
|
for bsz in 4 8 16 32 64; do
|
||||||
for BATCH_SIZE in 2 4 8; do
|
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 | tee logs/${GPU}_${bsz}_1024.txt
|
||||||
python ./benchmark.py \
|
|
||||||
--model="13b" \
|
|
||||||
--dtype="fp16" \
|
|
||||||
--batch_size=${BATCH_SIZE} \
|
|
||||||
--seq_len=1024 \
|
|
||||||
--output_len=128 \
|
|
||||||
--mb_size=$((${BATCH_SIZE}/2)) \
|
|
||||||
--pp_size=2 \
|
|
||||||
--tp_size=2
|
|
||||||
done
|
|
||||||
|
|
||||||
# 13b, fp16, 2 gpu, 512, 512
|
|
||||||
for BATCH_SIZE in 2 4 8 16; do
|
|
||||||
python ./benchmark.py \
|
|
||||||
--model="13b" \
|
|
||||||
--dtype="fp16" \
|
|
||||||
--batch_size=${BATCH_SIZE} \
|
|
||||||
--seq_len=512 \
|
|
||||||
--output_len=512 \
|
|
||||||
--mb_size=$((${BATCH_SIZE}/2)) \
|
|
||||||
--pp_size=2 \
|
|
||||||
--tp_size=2
|
|
||||||
done
|
done
|
||||||
|
|
|
@ -7,11 +7,17 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.inference import InferenceEngine
|
from colossalai.inference import InferenceEngine
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import spawn
|
||||||
|
from colossalai.utils.device import get_current_device
|
||||||
|
|
||||||
|
INPUT_TEXTS = [
|
||||||
|
"What is the longest river in the world?",
|
||||||
|
"Explain the difference between process and thread in compouter science.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def run_inference(args):
|
def run_inference(args):
|
||||||
llama_model_path = args.model_path
|
llama_model_path = args.model_path
|
||||||
llama_tokenize_path = args.tokenizer_path
|
llama_tokenize_path = args.tokenizer_path or args.model_path
|
||||||
|
|
||||||
max_input_len = args.max_input_len
|
max_input_len = args.max_input_len
|
||||||
max_output_len = args.max_output_len
|
max_output_len = args.max_output_len
|
||||||
|
@ -22,11 +28,10 @@ def run_inference(args):
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
|
tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
|
||||||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
if args.quant is None:
|
if args.quant is None:
|
||||||
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.unk_token_id)
|
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id)
|
||||||
model = model.half()
|
|
||||||
elif args.quant == "gptq":
|
elif args.quant == "gptq":
|
||||||
from auto_gptq import AutoGPTQForCausalLM
|
from auto_gptq import AutoGPTQForCausalLM
|
||||||
|
|
||||||
|
@ -45,18 +50,21 @@ def run_inference(args):
|
||||||
model=model,
|
model=model,
|
||||||
max_input_len=max_input_len,
|
max_input_len=max_input_len,
|
||||||
max_output_len=max_output_len,
|
max_output_len=max_output_len,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
micro_batch_size=micro_batch_size,
|
micro_batch_size=micro_batch_size,
|
||||||
quant=args.quant,
|
quant=args.quant,
|
||||||
|
dtype=args.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_tokens = {
|
inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
|
||||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
|
inputs = {k: v.to(get_current_device()) for k, v in inputs.items()}
|
||||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
|
outputs = engine.generate(inputs)
|
||||||
}
|
|
||||||
|
|
||||||
outputs = engine.generate(input_tokens)
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(tokenizer.batch_decode(outputs))
|
output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
for input_text, output_text in zip(INPUT_TEXTS, output_texts):
|
||||||
|
print(f"Input: {input_text}")
|
||||||
|
print(f"Output: {output_text}")
|
||||||
|
|
||||||
|
|
||||||
def run_tp_pipeline_inference(rank, world_size, port, args):
|
def run_tp_pipeline_inference(rank, world_size, port, args):
|
||||||
|
@ -67,8 +75,8 @@ def run_tp_pipeline_inference(rank, world_size, port, args):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
|
parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
|
||||||
parser.add_argument("--tokenizer_path", type=str, help="Tokenizer path", required=True)
|
parser.add_argument("-i", "--input", default="What is the longest river in the world?")
|
||||||
|
parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-q",
|
"-q",
|
||||||
"--quant",
|
"--quant",
|
||||||
|
@ -78,12 +86,13 @@ if __name__ == "__main__":
|
||||||
help="quantization type: 'gptq' or 'smoothquant'",
|
help="quantization type: 'gptq' or 'smoothquant'",
|
||||||
)
|
)
|
||||||
parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name")
|
parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name")
|
||||||
parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size")
|
parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
|
||||||
parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Pipeline parallel size")
|
parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
|
||||||
parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
|
parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
|
||||||
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
|
parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length")
|
||||||
parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length")
|
parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length")
|
||||||
parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
|
parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
|
||||||
|
parser.add_argument("--dtype", default="fp16", type=str)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)
|
spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
transformers==4.34.0
|
transformers==4.34.0
|
||||||
packaging
|
|
||||||
ninja
|
|
||||||
auto-gptq==0.5.0
|
auto-gptq==0.5.0
|
||||||
git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
|
git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
|
||||||
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
|
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
|
||||||
|
|
Loading…
Reference in New Issue