mirror of https://github.com/hpcaitech/ColossalAI
[inference] update examples and engine (#5073)
* update examples and engine * fix choices * update examplepull/5078/head
parent
0c7d8bebd5
commit
fb103cfd6e
|
@ -1,4 +1,4 @@
|
|||
from .engine import CaiInferEngine
|
||||
from .engine import InferenceEngine
|
||||
from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy
|
||||
|
||||
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
|
||||
__all__ = ["InferenceEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
from .engine import CaiInferEngine
|
||||
from .engine import InferenceEngine
|
||||
|
||||
__all__ = ["CaiInferEngine"]
|
||||
__all__ = ["InferenceEngine"]
|
||||
|
|
|
@ -3,7 +3,6 @@ from typing import Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
|
@ -27,9 +26,9 @@ _supported_models = [
|
|||
]
|
||||
|
||||
|
||||
class CaiInferEngine:
|
||||
class InferenceEngine:
|
||||
"""
|
||||
CaiInferEngine is a class that handles the pipeline parallel inference.
|
||||
InferenceEngine is a class that handles the pipeline parallel inference.
|
||||
|
||||
Args:
|
||||
tp_size (int): the size of tensor parallelism.
|
||||
|
@ -42,27 +41,6 @@ class CaiInferEngine:
|
|||
max_input_len (int): the maximum input length.
|
||||
max_output_len (int): the maximum output length.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from colossalai.inference import InferEngine
|
||||
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
|
||||
import colossalai
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
|
||||
# assume the model is infered with 2 pipeline stages
|
||||
inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
|
||||
|
||||
input = ["Introduce a landmark in China ","Introduce a landmark in China "]
|
||||
data = tokenizer(input, return_tensors='pt')
|
||||
output = inferengine.inference([data.to('cuda').data])
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -148,7 +126,7 @@ class CaiInferEngine:
|
|||
if quant == "gptq":
|
||||
self.gptq_manager.post_init_gptq_buffer(self.model)
|
||||
|
||||
def generate(self, input_list: Union[BatchEncoding, dict]):
|
||||
def generate(self, input_list: Union[list, dict]):
|
||||
"""
|
||||
Args:
|
||||
input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
|
||||
|
@ -157,11 +135,7 @@ class CaiInferEngine:
|
|||
out (list): a list of output data, each element is a list of token.
|
||||
timestamp (float): the time cost of the inference, only return when verbose is `True`.
|
||||
"""
|
||||
assert isinstance(
|
||||
input_list, (BatchEncoding, dict)
|
||||
), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
|
||||
if isinstance(input_list, BatchEncoding):
|
||||
input_list = input_list.data
|
||||
|
||||
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
|
||||
if self.verbose:
|
||||
return out, timestamp
|
||||
|
|
|
@ -29,7 +29,7 @@ def parse_args():
|
|||
type=str,
|
||||
help="location of the calibration dataset",
|
||||
)
|
||||
parser.add_argument("--num-samples", type=int, default=512)
|
||||
parser.add_argument("--num-samples", type=int, default=10)
|
||||
parser.add_argument("--seq-len", type=int, default=512)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
@ -41,13 +41,12 @@ def main():
|
|||
model_path = args.model_name
|
||||
dataset_path = args.dataset_path
|
||||
output_path = args.output_path
|
||||
num_samples = 10
|
||||
seq_len = 512
|
||||
num_samples = args.num_samples
|
||||
seq_len = args.seq_len
|
||||
|
||||
model, tokenizer = build_model_and_tokenizer(model_path)
|
||||
if not os.path.exists(dataset_path):
|
||||
print(f"Cannot find the dataset at {args.dataset_path}")
|
||||
raise FileNotFoundError
|
||||
raise FileNotFoundError(f"Cannot find the dataset at {args.dataset_path}")
|
||||
dataset = load_dataset("json", data_files=dataset_path, split="train")
|
||||
|
||||
model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
|
|
@ -1,72 +0,0 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference import CaiInferEngine
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import spawn
|
||||
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
|
||||
|
||||
def run_llama_inference(args):
|
||||
quantized_model_dir = args.quantized_path
|
||||
max_batch_size = args.max_batch_size
|
||||
max_input_len = args.max_input_len
|
||||
max_output_len = args.max_output_len
|
||||
micro_batch_size = args.micro_batch_size
|
||||
# load quantized model to the first GPU
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
quantized_model_dir, inject_fused_attention=False, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
engine = CaiInferEngine(
|
||||
tp_size=2,
|
||||
pp_size=2,
|
||||
model=model,
|
||||
max_batch_size=max_batch_size,
|
||||
max_input_len=max_input_len,
|
||||
max_output_len=max_output_len,
|
||||
micro_batch_size=micro_batch_size,
|
||||
quant="gptq",
|
||||
)
|
||||
|
||||
def data_gen():
|
||||
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
inputs = data_gen()
|
||||
for k, v in inputs.items():
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = 16
|
||||
inputs[k] = v.to("cuda").repeat(*new_shape)
|
||||
|
||||
output = engine.generate(inputs)
|
||||
if dist.get_rank() == 0:
|
||||
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
|
||||
|
||||
|
||||
def run_gptq_infernece(rank, world_size, port, args):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_llama_inference(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
|
||||
parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
|
||||
parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
|
||||
parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
|
||||
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
|
||||
parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
|
||||
args = parser.parse_args()
|
||||
|
||||
spawn(run_gptq_infernece, args.tp_size * args.pp_size, args=args)
|
|
@ -1,86 +0,0 @@
|
|||
import argparse
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference import CaiInferEngine
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def run_inference(args):
|
||||
llama_model_path = args.path
|
||||
max_input_len = args.max_input_len
|
||||
max_output_len = args.max_output_len
|
||||
max_batch_size = args.batch_size
|
||||
micro_batch_size = args.micro_batch_size
|
||||
tp_size = args.tp_size
|
||||
pp_size = args.pp_size
|
||||
rank = dist.get_rank()
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
|
||||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
||||
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
|
||||
model = model.half()
|
||||
|
||||
model = transformers.LlamaForCausalLM(
|
||||
transformers.LlamaConfig(
|
||||
vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
|
||||
)
|
||||
)
|
||||
|
||||
engine = CaiInferEngine(
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
model=model,
|
||||
max_output_len=max_output_len,
|
||||
micro_batch_size=micro_batch_size,
|
||||
)
|
||||
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
|
||||
}
|
||||
|
||||
iters = 10
|
||||
warmup = 3
|
||||
times = []
|
||||
|
||||
for i in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
outputs = engine.generate(input_tokens)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
if rank == 0:
|
||||
out_len = len(outputs[0])
|
||||
print("generation time {} s".format(str(end - start)))
|
||||
print(out_len)
|
||||
times.append((end - start) / out_len)
|
||||
if rank == 0:
|
||||
times = times[warmup:]
|
||||
latency = sum(times) / len(times)
|
||||
print("total process latency is : " + str(latency) + " s")
|
||||
print("total throughput is : " + str(1 / latency * max_batch_size))
|
||||
|
||||
|
||||
def run_tp_pipeline_inference(rank, world_size, port, args):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_inference(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
|
||||
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("-pp", "--pp_size", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=64, help="Maximum batch size")
|
||||
parser.add_argument("--max_input_len", type=int, default=512, help="Maximum input length")
|
||||
parser.add_argument("--max_output_len", type=int, default=256, help="Maximum output length")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Micro batch size")
|
||||
|
||||
args = parser.parse_args()
|
||||
spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)
|
|
@ -1,69 +0,0 @@
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference import CaiInferEngine
|
||||
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def run_llama_inference(args):
|
||||
quantized_model_dir = args.quantized_path
|
||||
max_batch_size = args.max_batch_size
|
||||
max_input_len = args.max_input_len
|
||||
max_output_len = args.max_output_len
|
||||
micro_batch_size = args.micro_batch_size
|
||||
|
||||
def data_gen():
|
||||
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
inputs = data_gen()
|
||||
for k, v in inputs.items():
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = 16
|
||||
inputs[k] = v.to("cuda").repeat(*new_shape)
|
||||
|
||||
model = SmoothLlamaForCausalLM.from_quantized(quantized_model_dir, model_basename="llama-7b")
|
||||
model = model.cuda()
|
||||
|
||||
engine = CaiInferEngine(
|
||||
tp_size=2,
|
||||
pp_size=2,
|
||||
model=model,
|
||||
max_batch_size=max_batch_size,
|
||||
max_input_len=max_input_len,
|
||||
max_output_len=max_output_len,
|
||||
micro_batch_size=micro_batch_size,
|
||||
quant="smoothquant",
|
||||
)
|
||||
|
||||
output = engine.generate(inputs)
|
||||
if dist.get_rank() == 0:
|
||||
assert len(output[0]) == 32, f"{len(output)}, {32}"
|
||||
|
||||
|
||||
def run_smoothquant_inference(rank, world_size, port, args):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_llama_inference(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
|
||||
parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size")
|
||||
parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size")
|
||||
parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size")
|
||||
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
|
||||
parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length")
|
||||
|
||||
args = parser.parse_args()
|
||||
spawn(run_smoothquant_inference, args.tp_size * args.pp_size, args=args)
|
|
@ -0,0 +1,89 @@
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference import InferenceEngine
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def run_inference(args):
|
||||
llama_model_path = args.model_path
|
||||
llama_tokenize_path = args.tokenizer_path
|
||||
|
||||
max_input_len = args.max_input_len
|
||||
max_output_len = args.max_output_len
|
||||
max_batch_size = args.batch_size
|
||||
micro_batch_size = args.micro_batch_size
|
||||
tp_size = args.tp_size
|
||||
pp_size = args.pp_size
|
||||
rank = dist.get_rank()
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
|
||||
tokenizer.pad_token_id = tokenizer.unk_token_id
|
||||
|
||||
if args.quant is None:
|
||||
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.unk_token_id)
|
||||
model = model.half()
|
||||
elif args.quant == "gptq":
|
||||
from auto_gptq import AutoGPTQForCausalLM
|
||||
|
||||
model = AutoGPTQForCausalLM.from_quantized(
|
||||
llama_model_path, inject_fused_attention=False, device=torch.cuda.current_device()
|
||||
)
|
||||
elif args.quant == "smoothquant":
|
||||
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
|
||||
|
||||
model = SmoothLlamaForCausalLM.from_quantized(llama_model_path, model_basename=args.smoothquant_base_name)
|
||||
model = model.cuda()
|
||||
|
||||
engine = InferenceEngine(
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
model=model,
|
||||
max_input_len=max_input_len,
|
||||
max_output_len=max_output_len,
|
||||
micro_batch_size=micro_batch_size,
|
||||
quant=args.quant,
|
||||
)
|
||||
|
||||
input_tokens = {
|
||||
"input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
|
||||
"attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
|
||||
}
|
||||
|
||||
outputs = engine.generate(input_tokens)
|
||||
if rank == 0:
|
||||
print(tokenizer.batch_decode(outputs))
|
||||
|
||||
|
||||
def run_tp_pipeline_inference(rank, world_size, port, args):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_inference(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
|
||||
parser.add_argument("--tokenizer_path", type=str, help="Tokenizer path", required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quant",
|
||||
type=str,
|
||||
choices=["gptq", "smoothquant"],
|
||||
default=None,
|
||||
help="quantization type: 'gptq' or 'smoothquant'",
|
||||
)
|
||||
parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name")
|
||||
parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size")
|
||||
parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Pipeline parallel size")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
|
||||
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
|
||||
parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
|
||||
|
||||
args = parser.parse_args()
|
||||
spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)
|
|
@ -3,5 +3,4 @@ packaging
|
|||
ninja
|
||||
auto-gptq==0.5.0
|
||||
git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
|
||||
git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
|
||||
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
|
||||
|
|
|
@ -7,7 +7,7 @@ import transformers
|
|||
from packaging import version
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference import CaiInferEngine
|
||||
from colossalai.inference import InferenceEngine
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
|
||||
|
@ -36,7 +36,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
|||
transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
|
||||
)
|
||||
|
||||
engine = CaiInferEngine(
|
||||
engine = InferenceEngine(
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
model=model,
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.distributed as dist
|
|||
from packaging import version
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference import CaiInferEngine
|
||||
from colossalai.inference import InferenceEngine
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
@ -44,7 +44,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
|||
)
|
||||
model = ChatGLMForConditionalGeneration(chatglm_config)
|
||||
|
||||
engine = CaiInferEngine(
|
||||
engine = InferenceEngine(
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
model=model,
|
||||
|
|
|
@ -7,7 +7,7 @@ import transformers
|
|||
from packaging import version
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference import CaiInferEngine
|
||||
from colossalai.inference import InferenceEngine
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
|
||||
|
@ -41,7 +41,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
|
|||
)
|
||||
)
|
||||
|
||||
engine = CaiInferEngine(
|
||||
engine = InferenceEngine(
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
model=model,
|
||||
|
|
Loading…
Reference in New Issue