diff --git a/colossalai/inference/engine/engine.py b/colossalai/inference/engine/engine.py
index a9ffc36a8..61da5858a 100644
--- a/colossalai/inference/engine/engine.py
+++ b/colossalai/inference/engine/engine.py
@@ -33,13 +33,16 @@ class InferenceEngine:
     Args:
         tp_size (int): the size of tensor parallelism.
         pp_size (int): the size of pipeline parallelism.
+        dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'.
         model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
-        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
-        micro_batch_size (int): the micro batch size.
+        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. It will be determined by the model type if not provided.
+        micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
         max_batch_size (int): the maximum batch size.
         max_input_len (int): the maximum input length.
         max_output_len (int): the maximum output length.
+        quant (str): the quantization method, should be one of 'smoothquant', 'gptq', None.
+        verbose (bool): whether to return the time cost of each step.
 
     """
 
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index e1a2d38cd..72480526b 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -69,6 +69,8 @@ class GenerateSchedule(PipelineSchedule):
         batch = tree_map(partial(to_device, device=device), batch)
         self.batch = batch
         self.batch_size = get_batch_size(batch)
+        if self.stage_manager.num_stages == 1:
+            self.microbatch_size = self.batch_size
         self.microbatch_offset = 0
         assert (
             self.batch_size % self.microbatch_size == 0
diff --git a/examples/inference/_utils.py b/examples/inference/_utils.py
deleted file mode 100644
index 67d897836..000000000
--- a/examples/inference/_utils.py
+++ /dev/null
@@ -1,19 +0,0 @@
-def print_perf_stats(latency_set, config, bs, warmup=3):
-    # trim warmup queries
-    latency_set = list(latency_set)
-    latency_set = latency_set[warmup:]
-    count = len(latency_set)
-
-    if count > 0:
-        latency_set.sort()
-        avg = sum(latency_set) / count
-        num_layers = (
-            getattr(config, "num_layers") if hasattr(config, "num_layers") else getattr(config, "num_hidden_layers")
-        )
-        num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
-        num_bytes = 2  # float16
-
-        print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
-        print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
-        print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
-        print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
diff --git a/examples/inference/benchmark.py b/examples/inference/benchmark.py
deleted file mode 100644
index a20983fd1..000000000
--- a/examples/inference/benchmark.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import argparse
-import os
-import time
-
-import torch
-import torch.distributed as dist
-import transformers
-
-import colossalai
-from colossalai.inference import CaiInferEngine
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
-
-GIGABYTE = 1024**3
-MEGABYTE = 1024 * 1024
-
-
-def data_gen(batch_size: int = 4, seq_len: int = 512):
-    input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
-    attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
-    data = dict(input_ids=input_ids, attention_mask=attention_mask)
-    for k, v in data.items():
-        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
-            new_shape = [1] * v.dim()
-            new_shape[0] = batch_size
-            data[k] = v.to("cuda").repeat(*new_shape)
-    return data
-
-
-def print_details_info(timestamps, model_config, args, whole_end2end):
-    log_file_name = f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.output_len}_bsz{args.batch_size}_mbsz{args.mb_size}.log"
-    os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
-
-    if dist.get_rank() == 0:
-        prefill = []
-        encoder = []
-        end2end = []
-        for timestamp in timestamps:
-            prefill.append(timestamp[1] - timestamp[0])
-            encoder.append(
-                sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
-            )
-            end2end.append(timestamp[-1] - timestamp[0])
-        print(whole_end2end)
-
-        with open(
-            log_file_name,
-            "w+",
-        ) as f:
-            mb_avg_end2end = sum(end2end) / len(end2end)
-            mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)
-            whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
-            num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
-            num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
-            if args.dtype in ["fp16", "bf16"]:
-                num_bytes = 2
-            else:
-                num_bytes = 4
-
-            f.write(
-                f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.output_len}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
-            )
-            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
-            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
-            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
-            f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
-            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
-            f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
-            f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
-            f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
-            f.write("----------------------------------------------------------\n")
-
-    if torch.cuda.is_available():
-        current_device = torch.cuda.current_device()
-
-        # free memory and the total available memory in bytes
-        global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info()
-        memory_allocated = torch.cuda.memory_allocated()
-        max_memory_allocated = torch.cuda.max_memory_allocated()
-        memory_reserved = torch.cuda.memory_reserved()
-        max_memory_reserved = torch.cuda.max_memory_reserved()
-        with open(
-            log_file_name,
-            "a",
-        ) as f:
-            f.write(
-                f"\nCurrently using GPU: {current_device}\n"
-                f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
-                f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n"
-                f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
-                f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
-                f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
-                f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
-            )
-
-
-def benchmark_inference(args):
-    if args.model == "toy":
-        model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))
-    elif args.model == "7b":
-        model = transformers.LlamaForCausalLM(
-            transformers.LlamaConfig(
-                hidden_size=4096,
-                intermediate_size=11008,
-                num_attention_heads=32,
-                num_hidden_layers=32,
-                num_key_value_heads=32,
-            )
-        )
-    elif args.model == "13b":
-        model = transformers.LlamaForCausalLM(
-            transformers.LlamaConfig(
-                hidden_size=5120,
-                intermediate_size=13824,
-                num_attention_heads=40,
-                num_hidden_layers=40,
-                num_key_value_heads=40,
-            )
-        )
-    else:
-        raise NotImplementedError
-
-    engine = CaiInferEngine(
-        pp_size=args.pp_size,
-        tp_size=args.tp_size,
-        dtype=args.dtype,
-        micro_batch_size=args.mb_size,
-        model=model,
-        verbose=True,
-        max_batch_size=args.mb_size,
-        max_input_len=args.seq_len,
-        max_output_len=args.output_len,
-    )
-    data = data_gen(args.batch_size, args.seq_len)
-
-    torch.cuda.synchronize()
-    whole_end2end = time.time()
-    output, timestamps = engine.generate(data)
-    torch.cuda.synchronize()
-    whole_end2end = time.time() - whole_end2end
-
-    print_details_info(timestamps, model.config, args, whole_end2end)
-
-
-def hybrid_inference(rank, world_size, port, args):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    benchmark_inference(args)
-
-
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def benchmark(args):
-    spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", default="toy", help="the size of model")
-    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
-    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
-    parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
-    parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
-    parser.add_argument("--tp_size", type=int, default=2, help="pipeline size")
-    parser.add_argument("--output_len", type=int, default=16, help="Output length")
-    parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
-    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
-    args = parser.parse_args()
-    benchmark(args)
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
new file mode 100644
index 000000000..9a26098b3
--- /dev/null
+++ b/examples/inference/benchmark_llama.py
@@ -0,0 +1,168 @@
+import argparse
+import time
+
+import torch
+import torch.distributed as dist
+import transformers
+
+import colossalai
+import colossalai.utils.device as device_utils
+from colossalai.inference import InferenceEngine
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.utils.device import get_current_device
+
+GIGABYTE = 1024**3
+MEGABYTE = 1024 * 1024
+
+CONFIG_MAP = {
+    "toy": transformers.LlamaConfig(num_hidden_layers=4),
+    "llama-7b": transformers.LlamaConfig(
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_attention_heads=32,
+        num_hidden_layers=32,
+        num_key_value_heads=32,
+        max_position_embeddings=2048,
+    ),
+    "llama-13b": transformers.LlamaConfig(
+        hidden_size=5120,
+        intermediate_size=13824,
+        num_attention_heads=40,
+        num_hidden_layers=40,
+        num_key_value_heads=40,
+        max_position_embeddings=2048,
+    ),
+    "llama2-7b": transformers.LlamaConfig(
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_attention_heads=32,
+        num_hidden_layers=32,
+        num_key_value_heads=32,
+        max_position_embeddings=4096,
+    ),
+    "llama2-13b": transformers.LlamaConfig(
+        hidden_size=5120,
+        intermediate_size=13824,
+        num_attention_heads=40,
+        num_hidden_layers=40,
+        num_key_value_heads=40,
+        max_position_embeddings=4096,
+    ),
+}
+
+
+def data_gen(batch_size: int = 4, seq_len: int = 512):
+    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
+    attention_mask = torch.ones_like(input_ids)
+    data = dict(input_ids=input_ids, attention_mask=attention_mask)
+    return data
+
+
+def print_details_info(outputs, model_config, args, whole_end2end):
+    msg: str = ""
+
+    if dist.get_rank() == 0:
+        msg += "-------Perf Summary-------\n"
+        if args.verbose:
+            timestamps = outputs[1]
+            prefill = []
+            encoder = []
+            end2end = []
+            for timestamp in timestamps:
+                prefill.append(timestamp[1] - timestamp[0])
+                encoder.append(
+                    sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
+                )
+                end2end.append(timestamp[-1] - timestamp[0])
+
+            mb_avg_end2end = sum(end2end) / len(end2end)
+            mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)
+
+            msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n"
+            msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n"
+            msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n"
+            msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n"
+
+        whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
+        num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
+        num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
+        if args.dtype in ["fp16", "bf16"]:
+            num_bytes = 2
+        else:
+            num_bytes = 4
+
+        msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
+        msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
+        msg += f"Throughput: {args.output_len * args.batch_size / whole_end2end:.2f} tokens/s\n"
+        msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"
+
+    if torch.cuda.is_available():
+        msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n"
+        msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n"
+        msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n"
+
+    print(msg)
+
+
+def benchmark_inference(args):
+    config = CONFIG_MAP[args.model]
+    model = transformers.LlamaForCausalLM(config)
+    if dist.get_rank() == 0:
+        print("Model loaded")
+    engine = InferenceEngine(
+        pp_size=args.pp_size,
+        tp_size=args.tp_size,
+        dtype=args.dtype,
+        micro_batch_size=args.mb_size,
+        model=model,
+        verbose=args.verbose,
+        max_batch_size=args.batch_size,
+        max_input_len=args.seq_len,
+        max_output_len=args.output_len,
+    )
+    data = data_gen(args.batch_size, args.seq_len)
+
+    N_WARMUP_STEPS = 2
+
+    for _ in range(N_WARMUP_STEPS):
+        engine.generate(data)
+
+    torch.cuda.synchronize()
+    whole_end2end = time.time()
+    outputs = engine.generate(data)
+    torch.cuda.synchronize()
+    whole_end2end = time.time() - whole_end2end
+
+    print_details_info(outputs, model.config, args, whole_end2end)
+
+
+def hybrid_inference(rank, world_size, port, args):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    benchmark_inference(args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def benchmark(args):
+    spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m",
+        "--model",
+        default="toy",
+        help="the size of model",
+        choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
+    )
+    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
+    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
+    parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
+    parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
+    parser.add_argument("--tp_size", type=int, default=1, help="tensor parallel size")
+    parser.add_argument("--output_len", type=int, default=128, help="Output length")
+    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
+    parser.add_argument("-v", "--verbose", default=False, action="store_true")
+    args = parser.parse_args()
+    benchmark(args)
diff --git a/examples/inference/build_smoothquant_weight.py b/examples/inference/build_smoothquant_weight.py
index 0cb566886..d60ce1c1d 100644
--- a/examples/inference/build_smoothquant_weight.py
+++ b/examples/inference/build_smoothquant_weight.py
@@ -54,15 +54,6 @@ def main():
 
     model.save_quantized(output_path, model_basename="llama-7b")
 
-    model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
-    model = model.cuda()
-
-    generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
-    input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
-    out = model.generate(**input_tokens, **generate_kwargs)
-    text = tokenizer.batch_decode(out)
-    print("out is:", text)
-
 
 if __name__ == "__main__":
     main()
diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh
old mode 100644
new mode 100755
index 79008c7d0..394222ea6
--- a/examples/inference/run_benchmark.sh
+++ b/examples/inference/run_benchmark.sh
@@ -1,65 +1,15 @@
-script_dir=$(cd "$(dirname "$0")" && pwd)
-cd "${script_dir}"
+ROOT=$(realpath $(dirname $0))
+PY_SCRIPT=${ROOT}/benchmark_llama.py
+GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
 
-# toy model, 2tp*2pp 1024, 128
-python ./benchmark.py \
-    --model="toy" \
-    --dtype="fp16" \
-    --batch_size=2 \
-    --seq_len=1024 \
-    --output_len=128 \
-    --mb_size=1 \
-    --pp_size=2 \
-    --tp_size=2
+mkdir -p logs
 
-# 7b, fp16, 2 gpu, 1024, 128
-for BATCH_SIZE in 2 4 8 16; do
-    python ./benchmark.py \
-        --model="7b" \
-        --dtype="fp16" \
-        --batch_size=${BATCH_SIZE} \
-        --seq_len=1024 \
-        --output_len=128 \
-        --mb_size=$((${BATCH_SIZE}/2)) \
-        --pp_size=2 \
-        --tp_size=2
+# benchmark llama2-7b on a single GPU
+for bsz in 16 32 64; do
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 | tee logs/${GPU}_${bsz}_256.txt
 done
 
-# 7b, fp16, 2 gpu, 512, 512
-for BATCH_SIZE in 2 4 8 16 32; do
-    python ./benchmark.py \
-        --model="7b" \
-        --dtype="fp16" \
-        --batch_size=${BATCH_SIZE} \
-        --seq_len=512 \
-        --output_len=512 \
-        --mb_size=$((${BATCH_SIZE}/2)) \
-        --pp_size=2 \
-        --tp_size=2
-done
-
-# 7b, fp16, 2 gpu, 1024, 128
-for BATCH_SIZE in 2 4 8; do
-    python ./benchmark.py \
-        --model="13b" \
-        --dtype="fp16" \
-        --batch_size=${BATCH_SIZE} \
-        --seq_len=1024 \
-        --output_len=128 \
-        --mb_size=$((${BATCH_SIZE}/2)) \
-        --pp_size=2 \
-        --tp_size=2
-done
 
-# 13b, fp16, 2 gpu, 512, 512
-for BATCH_SIZE in 2 4 8 16; do
-    python ./benchmark.py \
-        --model="13b" \
-        --dtype="fp16" \
-        --batch_size=${BATCH_SIZE} \
-        --seq_len=512 \
-        --output_len=512 \
-        --mb_size=$((${BATCH_SIZE}/2)) \
-        --pp_size=2 \
-        --tp_size=2
+for bsz in 4 8 16 32 64; do
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 | tee logs/${GPU}_${bsz}_1024.txt
 done
diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py
index 8151518fe..8f85a9363 100644
--- a/examples/inference/run_llama_inference.py
+++ b/examples/inference/run_llama_inference.py
@@ -7,11 +7,17 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
 import colossalai
 from colossalai.inference import InferenceEngine
 from colossalai.testing import spawn
+from colossalai.utils.device import get_current_device
+
+INPUT_TEXTS = [
+    "What is the longest river in the world?",
+    "Explain the difference between process and thread in computer science.",
+]
 
 
 def run_inference(args):
     llama_model_path = args.model_path
-    llama_tokenize_path = args.tokenizer_path
+    llama_tokenize_path = args.tokenizer_path or args.model_path
     max_input_len = args.max_input_len
     max_output_len = args.max_output_len
 
@@ -22,11 +28,10 @@ def run_inference(args):
 
     rank = dist.get_rank()
     tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
-    tokenizer.pad_token_id = tokenizer.unk_token_id
+    tokenizer.pad_token_id = tokenizer.eos_token_id
 
     if args.quant is None:
-        model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.unk_token_id)
-        model = model.half()
+        model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id)
     elif args.quant == "gptq":
         from auto_gptq import AutoGPTQForCausalLM
 
@@ -45,18 +50,21 @@ def run_inference(args):
         model=model,
         max_input_len=max_input_len,
         max_output_len=max_output_len,
+        max_batch_size=max_batch_size,
         micro_batch_size=micro_batch_size,
         quant=args.quant,
+        dtype=args.dtype,
     )
 
-    input_tokens = {
-        "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
-        "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
-    }
+    inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
+    inputs = {k: v.to(get_current_device()) for k, v in inputs.items()}
+    outputs = engine.generate(inputs)
 
-    outputs = engine.generate(input_tokens)
     if rank == 0:
-        print(tokenizer.batch_decode(outputs))
+        output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        for input_text, output_text in zip(INPUT_TEXTS, output_texts):
+            print(f"Input: {input_text}")
+            print(f"Output: {output_text}")
 
 
 def run_tp_pipeline_inference(rank, world_size, port, args):
@@ -67,8 +75,8 @@ def run_tp_pipeline_inference(rank, world_size, port, args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
-    parser.add_argument("--tokenizer_path", type=str, help="Tokenizer path", required=True)
-
+    parser.add_argument("-i", "--input", default="What is the longest river in the world?")
+    parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
     parser.add_argument(
         "-q",
         "--quant",
@@ -78,12 +86,13 @@ if __name__ == "__main__":
         help="quantization type: 'gptq' or 'smoothquant'",
    )
     parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name")
-    parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size")
-    parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Pipeline parallel size")
+    parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
     parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
-    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
-    parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length")
+    parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length")
+    parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length")
     parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
+    parser.add_argument("--dtype", default="fp16", type=str)
 
     args = parser.parse_args()
     spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)
diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt
index 46a6b41bf..f85f9d88e 100644
--- a/requirements/requirements-infer.txt
+++ b/requirements/requirements-infer.txt
@@ -1,6 +1,4 @@
 transformers==4.34.0
-packaging
-ninja
 auto-gptq==0.5.0
 git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
 git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9