mirror of https://github.com/hpcaitech/ColossalAI
* add infer_struct and infer_config
* update codes
* change InferConfig
* Add hf_model_config to the engine
* rm _get_hf_model_config
* update codes
* made adjustments according to the feedback from the reviewer
* update codes
* add ci test for config and struct
* Add the logic of the inference engine
* update engine and test
* Recover cache_manager.py
* add logger
* fix conflict
* update codes
* update codes
* update model and tokenizer
* fix add the logic about shardformer
* change kvcache_manager docstring
* add policy
* fix ci bug in test_kvcache_manager.py
* remove codes related to tokenizer and move model_policy
* fix code style
* add ordered_set to requirements-infer.txt
* Delete extra empty lines
* add ordered_set to requirements-test.txt

pull/5258/head
Authored by yuehuayingxueluo 11 months ago, committed by FrankLeeeee
13 changed files with 553 additions and 170 deletions
@@ -1,65 +1,232 @@
from itertools import count
from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from .request_handler import RequestHandler

PP_AXIS, TP_AXIS = 0, 1

_supported_models = [
    "LlamaForCausalLM",
]

class InferenceEngine:
    """
    InferenceEngine which manages the inference process.

    It is responsible for launching the inference process, including:
        - Initialize the model and the distributed environment (if needed)
        - Launch the request_handler and the corresponding kv cache manager
        - Receive requests and generate texts
        - Log the generation process

    Args:
        model (nn.Module): Path or nn.Module of this model.
        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
        model_policy (Policy): The policy to shardformer the model. It will be determined by the model type if not provided.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: Optional["InferenceConfig"] = None,
        verbose: bool = False,
        model_policy: Policy = None,
    ) -> None:
        assert inference_config, "Please provide inference_config."

        self.tokenizer = tokenizer
        self.inference_config = inference_config
        self.model_config = model.config

        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
            self.dtype = torch.float32
        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
            self.dtype = torch.float16
            model.half()
        else:
            self.dtype = torch.bfloat16
            model.to(torch.bfloat16)

        if model_policy is None:
            model_policy = model_policy_map[self.model_config.model_type]()

        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)

        self.model = self._shardformer(
            model,
            model_policy,
            None,
            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
        )

        self.verbose = verbose
        if verbose:
            self.logger = get_dist_logger(__name__)

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.counter = count()
    def _verify_config(self) -> None:
        """
        Verify the input config.
        """
        if not isinstance(self.model, nn.Module):
            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
            self.tokenizer, PreTrainedTokenizer
        ):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )
        assert (
            self.model.__class__.__name__ in _supported_models
        ), f"Model {self.model.__class__.__name__} is not supported."
    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): Path or nn.Module of this model.
            model_policy (Policy): The policy to shardformer the model, which is determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the TP process group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by ShardFormer and moved to CUDA.
        """
        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"quant": self.inference_config.quant_mode},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model.cuda()
    def generate(
        self,
        generation_config: GenerationConfig = None,
    ) -> List[str]:
        """
        Executes the inference steps until all queued requests are finished.

        Args:
            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.

        Returns:
            List[str]: Inference result returned by one generation.
        """
        self.generation_config = generation_config

        output_list = []

        while self.request_handler.check_unfinished_seqs():
            output_list += self.step()

        return output_list
    def add_request(
        self,
        requests_id: List[int] = None,
        prompts: List[str] = None,
        prompts_token_ids: List[List[int]] = None,
    ) -> None:
        """
        Add requests.

        Args:
            requests_id (List[int], optional): The request IDs. Defaults to None.
            prompts (List[str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): Token ids of input prompts. Defaults to None.
        """

        block_size = self.inference_config.block_size

        if prompts_token_ids is None:
            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
            prompts_token_ids = []
            for prompt in prompts:
                prompts_token_ids.append(self.tokenizer.encode(prompt))

        prompts_num = len(prompts_token_ids)

        for i in range(prompts_num):
            if requests_id:
                request_id = requests_id[i]
            else:
                request_id = next(self.counter)
            if prompts is None:
                prompt = None
            else:
                prompt = prompts[i]
            sequence = Sequence(
                request_id,
                prompt,
                prompts_token_ids[i],
                block_size,
                None,
                None,
                self.tokenizer.eos_token_id,
                self.inference_config.max_output_len,
            )
            self.request_handler.add_sequence(sequence)
    def step(self) -> List[str]:
        """
        In each step, do the following:
            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Run the model to generate the next token.
            3. Update the waiting list and running list in RequestHandler and get finished sequences.
            4. Decode and return finished sequences.

        Returns:
            List[str]: Decoded finished sequences generated by one step.
        """

        if self.verbose:
            self.logger.info("Running generation step")

        output_list = []
        self.request_handler.schedule()

        # Uncomment once the development of RequestHandler is completed.
        # logits = self.model(batch)
        # self.request_handler.search_tokens(logits, self.generation_config)

        finished_sequences = self.request_handler.update()

        # Decode completed sentences.
        for seq in finished_sequences:
            if seq.prompt:
                output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
                output_list.append(seq.prompt + output_str)
            else:
                output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
                output_list.append(output_str)

        return output_list
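A minimal usage sketch of the engine above. This is not part of the diff: the checkpoint path is a placeholder, and a ColossalAI distributed context (colossalai.launch) is assumed to be initialized beforehand, as the test file later in this diff does.

# Usage sketch (assumptions: a local Llama checkpoint path, ColossalAI already launched).
from transformers import AutoTokenizer, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# colossalai.launch(...) is assumed to have been called, as in the test below.
model = LlamaForCausalLM.from_pretrained("path/to/llama")   # hypothetical checkpoint path
tokenizer = AutoTokenizer.from_pretrained("path/to/llama")  # hypothetical tokenizer path

inference_config = InferenceConfig()  # defaults; dtype/tp_size/pp_size can be customized
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# Queue prompts, then drain the request handler until every sequence finishes.
engine.add_request(prompts=["Introduce Beijing."])
outputs = engine.generate(None)
print(outputs)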
@@ -0,0 +1,7 @@
from .llama import LlamaModelInferPolicy

model_policy_map = {
    "llama": LlamaModelInferPolicy,
}

__all__ = ["LlamaModelInferPolicy", "model_policy_map"]
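For reference, a short sketch of how the engine resolves a policy from this map when no model_policy is passed, mirroring the lookup in InferenceEngine.__init__; the LlamaConfig here is only an illustrative stand-in for model.config.

# Sketch of the default-policy lookup done in InferenceEngine.__init__.
import transformers

from colossalai.inference.modeling.policy import model_policy_map

model_config = transformers.LlamaConfig()                # model_type == "llama"
policy_cls = model_policy_map[model_config.model_type]   # -> LlamaModelInferPolicy
model_policy = policy_cls()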
@@ -0,0 +1,7 @@
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy


class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
    # The code here is just for testing and will be modified later.
    def __init__(self) -> None:
        super().__init__()
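Because model_policy is an ordinary constructor argument, this class can also be passed explicitly instead of being looked up through model_policy_map. A sketch, continuing from the usage example after the engine code above:

# Overriding the default policy lookup explicitly (reuses model/tokenizer/inference_config from the earlier sketch).
from colossalai.inference.modeling.policy import LlamaModelInferPolicy

engine = InferenceEngine(
    model, tokenizer, inference_config, verbose=True, model_policy=LlamaModelInferPolicy()
)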
@@ -1,4 +1,5 @@
ordered_set
transformers==4.34.0
auto-gptq==0.5.0
git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
@@ -0,0 +1,44 @@
import pytest
import transformers
from transformers import AutoTokenizer

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import spawn


def check_inference_engine():
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(
            vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
        )
    )
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    inference_config = InferenceConfig()
    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

    inputs = [
        "介绍一下北京",  # "Tell me about Beijing"
        "介绍一下武汉",  # "Tell me about Wuhan"
    ]

    inference_engine.add_request(prompts=inputs)
    outputs = inference_engine.generate(None)

    # The model forward is still commented out in step(), so each output currently equals its prompt.
    for s1, s2 in zip(inputs, outputs):
        assert s1 == s2


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_inference_engine()


@pytest.mark.dist
def test_inference_engine():
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_inference_engine()