from typing import List, Tuple, Type, Union

import numpy as np
import PIL.Image
import torch.nn as nn
from diffusers import DiffusionPipeline
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from colossalai.inference.config import InferenceConfig
from colossalai.inference.utils import ModelType, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

__all__ = ["InferenceEngine"]


class InferenceEngine:
    """
    InferenceEngine manages the inference process.

    Args:
        model_or_path (nn.Module, DiffusionPipeline, or str): The model object (nn.Module or DiffusionPipeline) or the path to the model.
        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast], optional): The tokenizer to use for inference.
        inference_config (InferenceConfig, optional): Configuration information related to inference.
        verbose (bool): Whether or not to log the generation process.
        model_policy (Union[Policy, Type[Policy]], optional): The policy used to shard the model with shardformer. Determined by the model type if not provided.
    """
    def __init__(
        self,
        model_or_path: Union[nn.Module, str, DiffusionPipeline],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Union[Policy, Type[Policy]] = None,
    ) -> None:
        self.__dict__["_initialized"] = False  # use __dict__ directly to avoid calling __setattr__
        self.logger = get_dist_logger(__name__)  # needed by the ModelType.UNKNOWN branch below
        self.model_type = get_model_type(model_or_path=model_or_path)
        self.engine = None
        if self.model_type == ModelType.LLM:
            from .llm_engine import LLMEngine

            self.engine = LLMEngine(
                model_or_path=model_or_path,
                tokenizer=tokenizer,
                inference_config=inference_config,
                verbose=verbose,
                model_policy=model_policy,
            )
        elif self.model_type == ModelType.DIFFUSION_MODEL:
            from .diffusion_engine import DiffusionEngine

            self.engine = DiffusionEngine(
                model_or_path=model_or_path,
                inference_config=inference_config,
                verbose=verbose,
                model_policy=model_policy,
            )
        elif self.model_type == ModelType.UNKNOWN:
            self.logger.error("Model type must be either Diffusion or LLM!")

        self._initialized = True
        self._verify_args()

    def _verify_args(self) -> None:
        """Verify the input args."""
        assert self.engine is not None, "Please init Engine first"
        assert self._initialized, "Engine must be initialized"
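
    # Note: for an unrecognized model type, __init__ leaves self.engine as None,
    # so _verify_args() fails fast with the assertion above rather than deferring
    # the error to the first generate()/step() call.
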
    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        *args,
        **kwargs,
    ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
        """
        Execute the inference step.

        Args:
            request_ids (Union[List[int], int], optional): The request ID(s). Defaults to None.
            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
        """
        assert self.engine is not None, "Please init Engine first"
        return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
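
    # What the union return type of generate() means in practice (hedged: the
    # prompts are illustrative and the output type depends on the wrapped engine):
    #
    #     texts = engine.generate(prompts=["Hello, my name is"])   # LLM -> List[str]
    #     images = engine.generate(prompts=["a photo of a cat"])   # diffusion -> List[PIL.Image.Image] or np.ndarray
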
    def add_request(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        *args,
        **kwargs,
    ) -> None:
        """
        Add requests.

        Args:
            request_ids (Union[List[int], int], optional): The request ID(s). Defaults to None.
            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): Token IDs of the input prompts. Defaults to None.
            kwargs: For an LLM, these may include max_length, max_new_tokens, etc.;
                for a diffusion model, they may include prompt_2, prompt_3, num_images_per_prompt,
                do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3,
                prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
                negative_pooled_prompt_embeds, and clip_skip, in line with diffusers.
        """
        assert self.engine is not None, "Please init Engine first"
        self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
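
    # Sketch of kwargs pass-through for a diffusion request (hedged: the argument
    # names follow the docstring above and diffusers' conventions; values are
    # illustrative):
    #
    #     engine.add_request(
    #         request_ids=0,
    #         prompts="an astronaut riding a horse",
    #         num_images_per_prompt=2,
    #         negative_prompt="low quality",
    #     )
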
    def step(self):
        assert self.engine is not None, "Please init Engine first"
        return self.engine.step()

    def __getattr__(self, name):
        """
        The design logic of __getattr__ and __setattr__:
        1. Since InferenceEngine is a wrapper around DiffusionEngine/LLMEngine, we want to access
           every member of DiffusionEngine/LLMEngine as if it were a member of InferenceEngine itself.
        2. Inside InferenceEngine.__init__ we do not want to set attributes via self.__dict__["xxx"] = xxx;
           we want to keep the usual form self.xxx = xxx.
        So we set the attribute `_initialized`. After initialization, if a member cannot be found on
        InferenceEngine, we try to get it from self.engine (DiffusionEngine/LLMEngine).
        """
        if self.__dict__.get("_initialized", False):
            if name in self.__dict__:
                return self.__dict__[name]
            else:
                return getattr(self.engine, name)
        else:
            return self.__dict__[name]
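
    # Note: Python invokes __getattr__ only after normal attribute lookup fails,
    # so members defined on InferenceEngine itself (generate, step, ...) are found
    # directly; only unknown names reach the branch that forwards to self.engine.
    # For example (hypothetical attribute on the wrapped engine):
    #
    #     engine.tokenizer  # not on the wrapper -> getattr(self.engine, "tokenizer")
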
    def __setattr__(self, name, value):
        if self.__dict__.get("_initialized", False):
            if name in self.__dict__:
                self.__dict__[name] = value
            else:
                setattr(self.engine, name, value)
        else:
            self.__dict__[name] = value
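

# A hedged smoke-test sketch of the request/step loop. Everything below is
# illustrative: the checkpoint path is hypothetical, InferenceConfig() is assumed
# to be default-constructible, and what step() returns depends on the wrapped engine.
if __name__ == "__main__":
    engine = InferenceEngine(
        model_or_path="path/to/llama-checkpoint",  # hypothetical local path
        inference_config=InferenceConfig(),
        verbose=True,
    )
    engine.add_request(request_ids=0, prompts="Hello, my name is")
    outputs = engine.step()  # one scheduling/execution iteration of the wrapped engine
    print(outputs)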