From cba20525a81565fc86e13b78973ffa8210a05cd3 Mon Sep 17 00:00:00 2001 From: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Date: Mon, 8 Jul 2024 16:02:07 +0800 Subject: [PATCH] [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838) * Diffusion Model Inference support * Stable Diffusion 3 Support * pixartalpha support --- colossalai/inference/config.py | 48 +- colossalai/inference/core/base_engine.py | 90 ++ colossalai/inference/core/diffusion_engine.py | 200 +++++ colossalai/inference/core/engine.py | 800 ++---------------- colossalai/inference/core/llm_engine.py | 758 +++++++++++++++++ colossalai/inference/core/request_handler.py | 51 +- .../inference/modeling/models/diffusion.py | 54 ++ .../inference/modeling/models/pixart_alpha.py | 220 +++++ .../modeling/models/stablediffusion3.py | 178 ++++ .../inference/modeling/policy/__init__.py | 6 + .../inference/modeling/policy/pixart_alpha.py | 34 + .../modeling/policy/stablediffusion3.py | 34 + colossalai/inference/struct.py | 12 + colossalai/inference/utils.py | 39 +- .../stable_diffusion/sd3_generation.py | 75 ++ requirements/requirements.txt | 1 + 16 files changed, 1860 insertions(+), 740 deletions(-) create mode 100644 colossalai/inference/core/base_engine.py create mode 100644 colossalai/inference/core/diffusion_engine.py create mode 100644 colossalai/inference/core/llm_engine.py create mode 100644 colossalai/inference/modeling/models/diffusion.py create mode 100644 colossalai/inference/modeling/models/pixart_alpha.py create mode 100644 colossalai/inference/modeling/models/stablediffusion3.py create mode 100644 colossalai/inference/modeling/policy/pixart_alpha.py create mode 100644 colossalai/inference/modeling/policy/stablediffusion3.py create mode 100644 examples/inference/stable_diffusion/sd3_generation.py diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index e114e8a61..1beb86874 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified import logging from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers.generation import GenerationConfig @@ -396,3 +396,49 @@ class ModelShardInferenceConfig: use_cuda_kernel: bool = False use_spec_dec: bool = False use_flash_attn: bool = False + + +@dataclass +class DiffusionGenerationConfig: + """ + Param for diffusion model forward + """ + + prompt_2: Optional[Union[str, List[str]]] = None + prompt_3: Optional[Union[str, List[str]]] = None + height: Optional[int] = None + width: Optional[int] = None + num_inference_steps: int = None + timesteps: List[int] = None + guidance_scale: float = None + negative_prompt: Optional[Union[str, List[str]]] = ( + None # NOTE(@lry89757) in pixart default to "", in sd3 default to None + ) + negative_prompt_2: Optional[Union[str, List[str]]] = None + negative_prompt_3: Optional[Union[str, List[str]]] = None + num_images_per_prompt: Optional[int] = None + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None + latents: Optional[torch.FloatTensor] = None + prompt_embeds: Optional[torch.FloatTensor] = None + negative_prompt_embeds: Optional[torch.FloatTensor] = None + pooled_prompt_embeds: Optional[torch.FloatTensor] = None + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None + output_type: 
Optional[str] = None # "pil" + return_dict: bool = None + joint_attention_kwargs: Optional[Dict[str, Any]] = None + clip_skip: Optional[int] = None + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None + callback_on_step_end_tensor_inputs: List[str] = None + + def to_dict(self) -> Dict[str, Any]: + # NOTE(@lry89757) Only return entries whose value is not the default None + result = {} + for field in fields(self): + value = getattr(self, field.name) + if value is not None: + result[field.name] = value + return result + + @classmethod + def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig": + return cls(**kwargs) diff --git a/colossalai/inference/core/base_engine.py b/colossalai/inference/core/base_engine.py new file mode 100644 index 000000000..392dd2990 --- /dev/null +++ b/colossalai/inference/core/base_engine.py @@ -0,0 +1,90 @@ +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.config import ModelShardInferenceConfig +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + + +class BaseEngine(ABC): + @abstractmethod + def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None): + pass + + @abstractmethod + def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None): + """ + Initialize the model for the engine + """ + + @abstractmethod + def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs): + """ + Generate output for incoming requests + """ + + @abstractmethod + def add_request(self, prompts, request_ids=None, **kwargs): + """ + Add a new request to the engine + """ + + @abstractmethod + def step(self): + """ + Perform one forward step + """ + + @abstractmethod + def _verify_args(self): + """ + Verify the parameters and members of the class + """ + + @torch.inference_mode() + def capture_model(self): + """ + Use CUDA graph to capture the model + """ + raise NotImplementedError("This method should be implemented by subclasses") + + def _shardformer( + self, + model: nn.Module, + model_policy: Policy, + model_shard_infer_config: ModelShardInferenceConfig = None, + stage_manager: PipelineStageManager = None, + tp_group: ProcessGroupMesh = None, + **kwargs, + ) -> nn.Module: + """ + Initialize ShardConfig and replace the model with shardformer. + + Args: + model (nn.Module): Path or nn.Module of this model. + model_policy (Policy): The policy used by ShardFormer to shard the model, determined by the model type. + stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. + tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. + + Returns: + nn.Module: The model optimized by Shardformer.
+ """ + + shardconfig = ShardConfig( + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=(self.inference_config.tp_size > 1), + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs}, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model diff --git a/colossalai/inference/core/diffusion_engine.py b/colossalai/inference/core/diffusion_engine.py new file mode 100644 index 000000000..75b9889bf --- /dev/null +++ b/colossalai/inference/core/diffusion_engine.py @@ -0,0 +1,200 @@ +from itertools import count +from typing import List, Tuple, Type, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn as nn +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from torch import distributed as dist + +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig +from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.struct import DiffusionSequence +from colossalai.inference.utils import get_model_size, get_model_type +from colossalai.logging import get_dist_logger +from colossalai.shardformer.policies.base_policy import Policy + +from .base_engine import BaseEngine +from .request_handler import NaiveRequestHandler + +PP_AXIS, TP_AXIS = 0, 1 + + +class DiffusionEngine(BaseEngine): + def __init__( + self, + model_or_path: DiffusionPipeline | str, + inference_config: InferenceConfig = None, + verbose: bool = False, + model_policy: Policy | type[Policy] = None, + ) -> None: + self.inference_config = inference_config + self.dtype = inference_config.dtype + self.high_precision = inference_config.high_precision + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + self.model_shard_infer_config = inference_config.to_model_shard_inference_config() + + self.model_type = get_model_type(model_or_path=model_or_path) + + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) + + self.request_handler = NaiveRequestHandler() + + self.counter = count() + + self._verify_args() + + def _verify_args(self) -> None: + assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe" + + def init_model( + self, + model_or_path: Union[str, nn.Module, DiffusionPipeline], + model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + """ + Shard model or/and Load weight + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model. + model_inference_config: the configuration for modeling initialization when inference. + model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. 
+ """ + if isinstance(model_or_path, str): + model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype) + policy_map_key = model.__class__.__name__ + model = DiffusionPipe(model) + elif isinstance(model_or_path, DiffusionPipeline): + policy_map_key = model_or_path.__class__.__name__ + model = DiffusionPipe(model_or_path) + else: + self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!") + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + if self.verbose: + self.logger.info(f"the device is {self.device}") + + if self.verbose: + self.logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + model_policy = model_policy_map.get(policy_map_key) + + if not isinstance(model_policy, Policy): + try: + model_policy = model_policy() + except Exception as e: + raise ValueError(f"Unable to instantiate model policy: {e}") + + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + model_shard_infer_config, + None, + tp_group=tp_group, + ) + + self.model = model.to(self.device) + + if self.verbose: + self.logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + free_gpu_memory, _ = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + self.logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + def generate( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + generation_config: DiffusionGenerationConfig = None, + **kwargs, + ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]: + """ """ + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + prompts = [prompts] if isinstance(prompts, str) else prompts + request_ids = [request_ids] if isinstance(request_ids, int) else request_ids + + with torch.inference_mode(): + if prompts is not None: + self.add_request( + request_ids=request_ids, + prompts=prompts, + **gen_config_dict, + **kwargs, + ) + + output_reqs_list = [] + + # intuition: If user provide a generation config, we should replace the existing one. 
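# Example (sketch): how this diffusion generate() path is expected to be driven from user code,
# in the spirit of examples/inference/stable_diffusion/sd3_generation.py added by this PR.
# The checkpoint id and the generation settings below are illustrative assumptions, not values
# taken from that example file:
#
#   import colossalai
#   from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
#   from colossalai.inference.core.engine import InferenceEngine
#
#   colossalai.launch_from_torch()  # distributed init, as in the other inference examples
#   engine = InferenceEngine(
#       "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed model id
#       inference_config=InferenceConfig(),
#   )
#   images = engine.generate(
#       prompts=["An astronaut riding a green horse"],
#       generation_config=DiffusionGenerationConfig(num_inference_steps=28, guidance_scale=7.0),
#   )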
+ if generation_config is not None: + self.generation_config = generation_config + self.generation_config_dict = gen_config_dict + + while self.request_handler.check_unfinished_reqs(): + output_reqs_list += self.step() + + return output_reqs_list + + def add_request( + self, + prompts: Union[List[str], str], + request_ids: Union[List[int], int] = None, + **kwargs, + ): + if request_ids is not None and not isinstance(request_ids, list): + request_ids = [request_ids] + + if not isinstance(prompts, list): + prompts = [prompts] + + generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs) + prompts_num = len(prompts) + for i in range(prompts_num): + if request_ids: + assert isinstance( + request_ids[0], int + ), f"The request_id type must be int, but got {type(request_ids[0])}" + assert len(request_ids) == prompts_num + request_id = request_ids[i] + else: + request_id = next(self.counter) + + seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config) + + self.request_handler.add_sequence(seq) + + def step(self) -> List[PIL.Image.Image]: + """ + In each step, do the follows: + 1. Run RequestHandler.schedule() and get the batch used for inference. + 2. run forward to get List[Image] + Returns: + List[PIL.Image.Image]: Image Generated by one step. + """ + + input = self.request_handler.schedule() + ret = self.model(prompt=input.prompt, **input.generation_config.to_dict()) + return ret diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 8f8aef65e..5c9bdc321 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,57 +1,24 @@ -import time -from itertools import count -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import List, Tuple, Type, Union import numpy as np -import torch +import PIL.Image import torch.nn as nn -from torch import distributed as dist -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - GenerationConfig, - PreTrainedTokenizer, - PreTrainedTokenizerFast, -) -from transformers.models.llama.modeling_llama import LlamaForCausalLM +from diffusers import DiffusionPipeline +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from colossalai.accelerator import get_accelerator -from colossalai.cluster import ProcessGroupMesh -from colossalai.inference.batch_bucket import BatchBucket -from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig -from colossalai.inference.graph_runner import CUDAGraphRunner -from colossalai.inference.modeling.policy import model_policy_map -from colossalai.inference.sampler import search_tokens -from colossalai.inference.spec import Drafter, GlideInput -from colossalai.inference.struct import Sequence -from colossalai.inference.utils import get_model_size, has_index_file -from colossalai.interface import ModelWrapper -from colossalai.lazy import LazyInitContext -from colossalai.logging import get_dist_logger -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.inference.config import InferenceConfig +from colossalai.inference.utils import ModelType, get_model_type from colossalai.shardformer.policies.base_policy import Policy -from .request_handler import RequestHandler - __all__ = ["InferenceEngine"] -PP_AXIS, TP_AXIS = 0, 1 - -_supported_models = { - "LlamaForCausalLM": LlamaForCausalLM, - "BaichuanForCausalLM": AutoModelForCausalLM, 
-} - -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] - class InferenceEngine: """ InferenceEngine which manages the inference process.. Args: - model_or_path (nn.Module or str): Path or nn.Module of this model. + model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model. tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. verbose (bool): Determine whether or not to log the generation process. @@ -60,567 +27,68 @@ class InferenceEngine: def __init__( self, - model_or_path: Union[nn.Module, str], - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - inference_config: InferenceConfig, + model_or_path: Union[nn.Module, str, DiffusionPipeline], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None, + inference_config: InferenceConfig = None, verbose: bool = False, model_policy: Union[Policy, Type[Policy]] = None, ) -> None: - self.inference_config = inference_config - self.dtype = inference_config.dtype - self.high_precision = inference_config.high_precision - - self.verbose = verbose - self.logger = get_dist_logger(__name__) - self.model_shard_infer_config = inference_config.to_model_shard_inference_config() - - self.init_model(model_or_path, model_policy, self.model_shard_infer_config) - - self.generation_config = inference_config.to_generation_config(self.model_config) - self.generation_config_dict = self.generation_config.to_dict() - - self.tokenizer = tokenizer - self.tokenizer.pad_token = self.tokenizer.eos_token - - self.request_handler = RequestHandler(self.inference_config, self.model_config) - self.k_cache, self.v_cache = self.request_handler.get_kvcache() - # DISCUSS maybe move this into batch info? - - self.counter = count() - - self.use_cuda_graph = self.inference_config.use_cuda_graph - if self.use_cuda_graph: - self.graph_runners: Dict[int, CUDAGraphRunner] = {} - self.graph_memory_pool = None # Set during graph capture. - if verbose: - self.logger.info("Colossal AI CUDA Graph Capture on") - - self.capture_model(self.k_cache, self.v_cache) - - # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` - self.use_spec_dec = self.inference_config.use_spec_dec - - self.drafter_model = None - self.drafter = None - self.use_glide = False - self.n_spec_tokens = self.inference_config.max_n_spec_tokens - - self._verify_args() - - def init_model( - self, - model_or_path: Union[nn.Module, str], - model_policy: Union[Policy, Type[Policy]] = None, - model_shard_infer_config: ModelShardInferenceConfig = None, - ): - """ - Shard model or/and Load weight - - Args: - model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. - model_policy (Policy): the policy to replace the model. - model_inference_config: the configuration for modeling initialization when inference. - model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. - """ - pretrained_path = None - if isinstance(model_or_path, str): - import colossalai.interface.pretrained as pretrained_utils - - try: - hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) - arch = getattr(hf_config, "architectures")[0] - if arch in _supported_models.keys(): - if arch is "BaichuanForCausalLM": - self.logger.warning( - "Attention ! 
We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers" - ) - ctx = LazyInitContext(default_device="cuda") - with ctx: - model = _supported_models[arch].from_pretrained( - model_or_path, trust_remote_code=True, torch_dtype=self.dtype - ) - pretrained_path = pretrained_utils.get_pretrained_path(model) - else: - # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate - raise ValueError(f"Model {arch} is not supported.") - - except Exception as e: - self.logger.error( - f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" - ) - else: - model = model_or_path - - self.model_config = model.config - - torch.cuda.empty_cache() - init_gpu_memory = torch.cuda.mem_get_info()[0] - - self.device = get_accelerator().get_current_device() - if self.verbose: - self.logger.info(f"the device is {self.device}") - - model = model.to(self.dtype).eval() - - if self.verbose: - self.logger.info( - f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__ + self.model_type = get_model_type(model_or_path=model_or_path) + self.engine = None + if self.model_type == ModelType.LLM: + from .llm_engine import LLMEngine + + self.engine = LLMEngine( + model_or_path=model_or_path, + tokenizer=tokenizer, + inference_config=inference_config, + verbose=verbose, + model_policy=model_policy, ) - - if model_policy is None: - prefix = "nopadding" if not self.inference_config.pad_input else "padding" - model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" - model_policy = model_policy_map.get(model_policy_key) - - if not isinstance(model_policy, Policy): - try: - model_policy = model_policy() - except Exception as e: - raise ValueError(f"Unable to instantiate model policy: {e}") - - assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" - pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) - tp_group = pg_mesh.get_group_along_axis(TP_AXIS) - - self.model = self._shardformer( - model, - model_policy, - model_shard_infer_config, - None, - tp_group=tp_group, - ) - - self.model = ModelWrapper(model).to(self.device) - - if self.verbose: - self.logger.info( - f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + elif self.model_type == ModelType.DIFFUSION_MODEL: + from .diffusion_engine import DiffusionEngine + + self.engine = DiffusionEngine( + model_or_path=model_or_path, + inference_config=inference_config, + verbose=verbose, + model_policy=model_policy, ) + elif self.model_type == ModelType.UNKNOWN: + self.logger.error(f"Model Type either Difffusion or LLM!") - if pretrained_path: - from colossalai.inference.core.plugin import InferCheckpoint_io - - cpt_io = InferCheckpoint_io() - if_has_index_file, model_index_file = has_index_file(pretrained_path) - assert if_has_index_file, "the model path is invalid" - cpt_io.load_model(self.model, model_index_file) - - free_gpu_memory, _ = torch.cuda.mem_get_info() - peak_memory = init_gpu_memory - free_gpu_memory - if self.verbose: - self.logger.info( - f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" - ) - - 
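# Example (sketch): the ModelType-based dispatch above depends on get_model_type() from
# colossalai/inference/utils.py, which this PR modifies but which is not shown in this hunk.
# The shape below is only an assumption, written to make the routing concrete; the real helper
# may inspect the checkpoint differently:
#
#   def get_model_type(model_or_path) -> ModelType:
#       if isinstance(model_or_path, DiffusionPipeline):
#           return ModelType.DIFFUSION_MODEL
#       if isinstance(model_or_path, nn.Module):
#           return ModelType.LLM
#       if isinstance(model_or_path, str):
#           # a string path could name either family; presumably the on-disk config decides
#           ...
#       return ModelType.UNKNOWN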
@torch.inference_mode() - def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): - assert self.use_cuda_graph, "please turn on the cuda graph" - - if self.verbose: - self.logger.info("Colossal AI CUDA Graph Capture begin") - - t_capture_begin = time.perf_counter() - - block_size = self.inference_config.block_size - head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads - - # Prepare dummy inputs. These will be reused for all batch sizes. - max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - max_context_len_to_capture = self.inference_config.max_context_len_to_capture - max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size - input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda() - # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) - self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32) - self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE)) - self.graph_block_tables[0, :] = np.arange( - 0, max_num_blocks - ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len - block_tables = torch.from_numpy(self.graph_block_tables).cuda() - output_tensor = torch.zeros( - (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device - ) - fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor - - max_num_seqs = self.inference_config.max_batch_size - batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs] - sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda() - # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len - sequence_lengths[0] = torch.tensor( - self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32 - ).cuda() - - # NOTE: Capturing the largest batch size first may help reduce the - # memory usage of CUDA graph. 
- for batch_size in reversed(batch_size_capture_list): - if self.verbose: - self.logger.info(f"batch size {batch_size} graph capturing") - - input_meta_data = InputMetaData( - block_tables=block_tables[:batch_size], - sequence_lengths=sequence_lengths[:batch_size], - fd_inter_tensor=fd_inter_tensor, - batch_size=batch_size, - is_prompts=False, - use_cuda_graph=True, - high_precision=False, - kv_seq_len=sequence_lengths[:batch_size].max().item(), - head_dim=head_dim, - dtype=self.dtype, - ) - - graph_runner = CUDAGraphRunner(self.model) - graph_runner.capture( - input_tokens_ids[:batch_size], - output_tensor[:batch_size], - input_meta_data, - k_caches=k_cache, - v_caches=v_cache, - memory_pool=self.graph_memory_pool, - ) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[batch_size] = graph_runner - - t_capture_end = time.perf_counter() - - if self.verbose: - self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") + self._initialized = True + self._verify_args() def _verify_args(self) -> None: """Verify the input args""" - if not isinstance(self.inference_config, InferenceConfig): - raise TypeError("Invalid type of inference config provided.") - if not isinstance(self.model, nn.Module): - raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") - if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): - raise TypeError( - f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" - ) - if isinstance(self.model, ModelWrapper): - model = self.model.module - assert ( - model.__class__.__name__ in _supported_models.keys() - ), f"Model {self.model.__class__.__name__} is not supported." - - def _shardformer( - self, - model: nn.Module, - model_policy: Policy, - model_shard_infer_config: ModelShardInferenceConfig = None, - stage_manager: PipelineStageManager = None, - tp_group: ProcessGroupMesh = None, - ) -> nn.Module: - """ - Initialize ShardConfig and replace the model with shardformer. - - Args: - model (nn.Module): Path or nn.Module of this model. - model_policy (Policy): The policy to shardformer model which is determined by the model type. - stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. - tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. - - Returns: - nn.Module: The model optimized by Shardformer. - """ - - shardconfig = ShardConfig( - tensor_parallel_process_group=tp_group, - pipeline_stage_manager=stage_manager, - enable_tensor_parallelism=(self.inference_config.tp_size > 1), - enable_fused_normalization=False, - enable_all_optimization=False, - enable_flash_attention=False, - enable_jit_fused=False, - enable_sequence_parallelism=False, - extra_kwargs={"model_shard_infer_config": model_shard_infer_config}, - ) - shardformer = ShardFormer(shard_config=shardconfig) - shard_model, _ = shardformer.optimize(model, model_policy) - return shard_model - - def enable_spec_dec( - self, - drafter_model: nn.Module = None, - n_spec_tokens: int = None, - use_glide_drafter: bool = False, - ) -> None: - """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. - - Args: - drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. - If provided, the previous drafter and drafter model, if exist, will be overwritten. 
- n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. - If not provided, `max_n_spec_tokens` in InferenceConfig will be used. - use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. - If True, the drafter model will be replaced by a glide model. - - ```python - ... - engine = InferenceEngine(model, tokenizer, inference_config) - - engine.enable_spec_dec(drafter_model, n_spec_tokens=5) - engine.generate(...) # Speculative Decoding - - engine.disable_spec_dec() - engine.generate(...) # Normal generation - - engine.enable_spec_dec() - engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens - engine.clear_spec_dec() - ``` - """ - - if drafter_model is None and self.drafter is None: - raise ValueError("Drafter not initialized. Please provide a Drafter Model") - if n_spec_tokens is not None: - assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens - self.n_spec_tokens = n_spec_tokens - if drafter_model is not None: - assert isinstance(drafter_model, nn.Module) - # overwrite the drafter, if exists - self.clear_spec_dec() - self.drafter_model = drafter_model - self.drafter = Drafter( - self.drafter_model, - self.tokenizer, - device=self.device, - dtype=self.dtype, - ) - - # check if the provided drafter model is compatible with GLIDE structure - # when `use_glide_drafter` is set to True - if ( - use_glide_drafter - and hasattr(drafter_model, "model") - and hasattr(drafter_model.model, "layers") - and hasattr(drafter_model.model.layers[0], "cross_attn") - ): - self.use_glide = use_glide_drafter - elif use_glide_drafter: - self.logger.warning( - f"`use_glide_drafter` is provided as {use_glide_drafter}, " - f"but the provided drafter model is not compatible with GLIDE structure." - f"Falling back to use the default drafter model (non-GLIDE)." - ) - self.request_handler.set_spec_dec_mode(self.n_spec_tokens) - # using speculative decoding for subsequent generations - self.use_spec_dec = True - - def disable_spec_dec(self) -> None: - """Disable using speculative decoding for subsequent generations.""" - self.request_handler.unset_spec_dec_mode() - # set back to the maximum number of tokens to speculate - self.n_spec_tokens = self.inference_config.max_n_spec_tokens - self.use_glide = False - self.use_spec_dec = False - - def clear_spec_dec(self) -> None: - """Clear relatable structures of speculative decoding, if exist.""" - if self.use_spec_dec: - self.disable_spec_dec() - if self.drafter_model or self.drafter: - self.drafter_model = None - self.drafter = None - torch.cuda.empty_cache() - self.use_glide = False - self.use_spec_dec = False - - def steps_spec_dec(self) -> List[Sequence]: - """ - Run Speculative Decoding steps. This is like retrieving a single batch and launch inference - with many steps of speculating by a drafter model as well as verifying by a main model. - - Returns: - List[Sequence]: finished sequences generated by one step. - """ - batch = self.request_handler.schedule() # prefill batch - assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." - - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - - if input_meta_data.use_cuda_graph: - model_executable = self.graph_runners[input_meta_data.batch_size] - else: - model_executable = self.model - - # 1. 
Prefill small model (Drafter) - fill past kv cache for drafter model - # NOTE For glide drafter models, we won't actually apply glide during prefill stage - drafter_out = self.drafter.speculate(input_token_ids, 1, None) - next_token_ids_spec = drafter_out.next_tokens - drafter_past_key_values = drafter_out.past_key_values - - # 2. Prefill main model (Verifier) - fill past kv cache for main model - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) - # append new inputs to the batch, temporarily - batch.append_batch_tokens(next_tokens) - self.request_handler.allocate_batch_spec_dec(batch, 1) - already_allocated_kv_len = batch.seq_lengths[0].item() - input_token_ids = batch.get_1D_inputs_spec_dec(1) - - finished_sequences = self.request_handler.update() - - while True: - # HACK Retrieve the running batch - # Using RequestHandler.schedule here will re-allocate same kv cache for the batch - batch = self.request_handler.running_bb # running batch - assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." - - # 3. Decoding - Drafter model speculates `n` tokens - glide_input = None - if self.use_glide: - glide_input = GlideInput( - batch.get_block_table_tensor(), - self.k_cache[-1], # use kv cahces of the last layer - self.v_cache[-1], - batch.get_sequence_lengths(), - n_spec_tokens=self.n_spec_tokens, - ) - - drafter_out = self.drafter.speculate( - input_token_ids, - self.n_spec_tokens, - drafter_past_key_values, - glide_input=glide_input, - ) - next_token_ids_spec = drafter_out.next_tokens - drafter_past_key_values = drafter_out.past_key_values - drafter_spec_length = drafter_out.speculated_length - - for next_token_id_spec in next_token_ids_spec: - self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) - cur_length = batch.seq_lengths[0].item() - if already_allocated_kv_len < cur_length: - self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) - already_allocated_kv_len = cur_length - - # 4. Decoding - Main model verifies `n` tokens in parallel - if drafter_spec_length < batch.num_tokens_to_verify: - batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - - next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) - - # 5. 
Compare and process the results - diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) - n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() - - # revoke appended tokens for each Sequence in the current batch - batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens - - # append the last correct token generated by the main model - self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) - - # trim past key values of the drafter model - drafter_past_key_values = Drafter.trim_kv_cache( - drafter_past_key_values, drafter_spec_length - n_matches - 1 - ) - - # prepare inputs for the next round of speculation - n = 1 if n_matches < drafter_spec_length else 2 - input_token_ids = batch.get_1D_inputs_spec_dec(n) - - self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) - finished_sequences = self.request_handler.update() - if len(finished_sequences) > 0: - break - - # Reset back the number of speculated tokens of the batch, - # this is used to handle the last round of speculation, in which case the number of speculated tokens - # by the drafter is less than the number of speculated tokens set to the engine. - batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens) - - return finished_sequences + assert self.engine is not None, "Please init Engine first" + assert self._initialized, "Engine must be initialized" def generate( self, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, - prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, - return_token_ids: bool = False, - generation_config: Optional[GenerationConfig] = None, - ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: + *args, + **kwargs, + ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]: """ Executing the inference step. Args: request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. - prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. - return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. - generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. - - Returns: - Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. - """ - - gen_config_dict = generation_config.to_dict() if generation_config is not None else {} - prompts = [prompts] if isinstance(prompts, str) else prompts - request_ids = [request_ids] if isinstance(request_ids, int) else request_ids - - with torch.inference_mode(): - if prompts is not None or prompts_token_ids is not None: - self.add_request( - request_ids=request_ids, - prompts=prompts, - prompts_token_ids=prompts_token_ids, - **gen_config_dict, - ) - - output_seqs_list = [] - total_tokens_list = [] - - # intuition: If user provide a generation config, we should replace the existing one. - if generation_config is not None: - self.generation_config = generation_config - self.generation_config_dict = gen_config_dict - - if self.use_spec_dec: - assert self.drafter is not None, "Drafter Model is not initialized." 
- while self.request_handler.check_unfinished_seqs(): - output_seqs_list += self.steps_spec_dec() - else: - while self.request_handler.check_unfinished_seqs(): - output_seqs_list += self.step() - - output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) - - for seq in output_seqs_list: - total_tokens_list.append(seq.input_token_id + seq.output_token_id) - - output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True) - - if return_token_ids: - output_tokens_list = [seq.output_token_id for seq in output_seqs_list] - return output_str, output_tokens_list - else: - return output_str - - @property - def has_prompt_template(self) -> bool: - """ """ - return self.inference_config.prompt_template is not None - - def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]: - """ - This method will format the input prompt according to the prompt template given to the InferenceConfig. """ - assert ( - self.has_prompt_template - ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig." - if isinstance(prompts, (list, tuple)): - return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] - elif isinstance(prompts, str): - return self.inference_config.prompt_template.format(input_text=prompts) - else: - raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") + assert self.engine is not None, "Please init Engine first" + return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs) def add_request( self, request_ids: Union[List[int], int] = None, prompts: Union[List[str], str] = None, - prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + *args, **kwargs, ) -> None: """ @@ -630,168 +98,36 @@ class InferenceEngine: request_ids (List[int], optional): The request ID. Defaults to None. prompts (Union[List[str], optional): Input prompts. Defaults to None. prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. + kwargs: for LLM, it could be max_length, max_new_tokens, etc + for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers """ + assert self.engine is not None, "Please init Engine first" + self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs) - # apply the prompt template to the input prompts - - if self.has_prompt_template and prompts is not None: - prompts = self.format_prompt(prompts) - - block_size = self.inference_config.block_size - - if request_ids is not None and not isinstance(request_ids, list): - request_ids = [request_ids] - - if prompts is not None and not isinstance(prompts, list): - prompts = [prompts] - - if prompts_token_ids is None: - assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." 
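# Example (sketch): both delegating wrappers above forward to the active backend engine, so the
# per-backend kwargs listed in the add_request() docstring pass straight through. Illustrative
# calls; prompt text and settings are assumptions:
#
#   # LLM backend: request-level kwargs such as max_new_tokens are consumed by the LLM path
#   engine.add_request(prompts=["What is the capital of France?"], max_new_tokens=64)
#
#   # Diffusion backend: kwargs are folded into a DiffusionGenerationConfig via from_kwargs()
#   images = engine.generate(
#       prompts=["A corgi wearing sunglasses"],
#       num_inference_steps=30,
#       guidance_scale=7.0,
#       negative_prompt="blurry, low quality",
#   )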
- prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ - "input_ids" - ] - - # list of torch Tensor - if isinstance(prompts_token_ids, list): - if isinstance(prompts_token_ids[0], torch.Tensor): - prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] - elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): - prompts_token_ids = prompts_token_ids.tolist() - else: - raise TypeError( - f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." - ) - - assert ( - len(prompts_token_ids[0]) <= self.inference_config.max_input_len - ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." - - prompts_num = len(prompts_token_ids) - - for i in range(prompts_num): - if request_ids: - assert isinstance( - request_ids[0], int - ), f"The request_id type must be int, but got {type(request_ids[0])}" - assert len(request_ids) == prompts_num - request_id = request_ids[i] - else: - request_id = next(self.counter) - if prompts == None: - prompt = None - else: - prompt = prompts[i] - - max_length = kwargs.get("max_length", None) - max_new_tokens = kwargs.get("max_new_tokens", None) - if max_length is None and max_new_tokens is None: - max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len - elif max_length is not None: - max_new_tokens = max_length - len(prompts_token_ids[i]) + def step(self): + assert self.engine is not None, "Please init Engine first" + return self.engine.step() - if not self.inference_config.enable_streamingllm: - assert ( - self.inference_config.max_output_len >= max_new_tokens - ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." 
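# (Illustrative arithmetic for the max_length handling above: with a 10-token prompt and
#  max_length=50, max_new_tokens = 50 - 10 = 40; if neither max_length nor max_new_tokens is
#  given, the value falls back to generation_config.max_new_tokens or max_output_len.)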
- - sequence = Sequence( - request_id, - prompt, - prompts_token_ids[i], - block_size, - None, - self.tokenizer.eos_token_id, - self.tokenizer.pad_token_id, - max_output_len=max_new_tokens, - ignore_eos=self.inference_config.ignore_eos, - ) - self.request_handler.add_sequence(sequence) - - def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: - input_ids = batch.get_1D_inputs() - sequence_lengths = batch.get_sequence_lengths() - - if batch.is_prompts: - n_tokens = sequence_lengths.sum().item() - else: - n_tokens = batch.current_batch_size - if batch.use_spec_dec: - n_tokens = batch.num_tokens_to_verify + 1 - assert n_tokens == input_ids.size(0) - n_tokens = n_tokens * batch.current_batch_size - output_tensor = torch.zeros( - (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device - ) - - batch_token_ids = None - if ( - self.generation_config.repetition_penalty != 1.0 - or self.generation_config.no_repeat_ngram_size > 0 - or self.generation_config.forced_eos_token_id is not None - ): - batch_token_ids = batch.batch_token_ids - - # only when we have the graph for specific decoding batch size can we use the cuda graph for inference - use_cuda_graph = False - if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): - use_cuda_graph = True - - input_meta_data = InputMetaData( - block_tables=batch.get_block_table_tensor(), - sequence_lengths=sequence_lengths, - fd_inter_tensor=batch.fd_inter_tensor, - batch_size=batch.current_batch_size, - is_prompts=batch.is_prompts, - use_cuda_kernel=self.inference_config.use_cuda_kernel, - use_cuda_graph=use_cuda_graph, - high_precision=self.high_precision, - kv_seq_len=sequence_lengths.max().item(), - head_dim=batch.head_dim, - dtype=batch.dtype, - use_spec_dec=batch.use_spec_dec, - num_tokens_to_verify=batch.num_tokens_to_verify, - batch_token_ids=batch_token_ids, - ) - - return input_ids, output_tensor, input_meta_data - - def step(self) -> List[str]: + def __getattr__(self, name): """ - In each step, do the follows: - 1. Run RequestHandler.schedule() and get the batch used for inference. - 2. Get the input, inputinfo and output placeholder from the batchbucket - 3. Run model to generate the next token - 4. Update waiting list and running list in RequestHandler and get finished sequences. - 5. Decode and return finished sequences. - - Returns: - List[str]: Decoded finished sequences generated by one step. + The Design logic of getattr, setattr: + 1. Since InferenceEngine is a wrapper for DiffusionEngine/LLMEngine, we hope to invoke all the member of DiffusionEngine/LLMEngine like we just call the member of InferenceEngine. + 2. When we call the __init__ of InferenceEngine, we don't want to setattr using self.__dict__["xxx"] = xxx, we want to use origin ways like self.xxx = xxx + So we set the attribute `_initialized`. 
And after initialized, if we couldn't get the member from InferenceEngine, we will try to get the member from self.engine(DiffusionEngine/LLMEngine) """ - - batch = self.request_handler.schedule() - - input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) - - if input_meta_data.use_cuda_graph: - model_executable = self.graph_runners[input_meta_data.batch_size] + if self.__dict__.get("_initialized", False): + if name in self.__dict__: + return self.__dict__[name] + else: + return getattr(self.engine, name) else: - model_executable = self.model + return self.__dict__[name] - # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. - logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - if self.inference_config.pad_input: - logits = logits[:, -1, :] - - if self.inference_config.enable_streamingllm: - updated_block_ids = batch.streamingllm_update_batch( - self.inference_config.start_token_size, self.inference_config.generated_token_size - ) - self.request_handler.streamingllm_free_block_tables(updated_block_ids) - - next_tokens = search_tokens( - self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids - ) - self.request_handler.append_next_tokens(next_tokens) - finished_sequences = self.request_handler.update() - - return finished_sequences + def __setattr__(self, name, value): + if self.__dict__.get("_initialized", False): + if name in self.__dict__: + self.__dict__[name] = value + else: + setattr(self.engine, name, value) + else: + self.__dict__[name] = value diff --git a/colossalai/inference/core/llm_engine.py b/colossalai/inference/core/llm_engine.py new file mode 100644 index 000000000..b973d371d --- /dev/null +++ b/colossalai/inference/core/llm_engine.py @@ -0,0 +1,758 @@ +import time +from itertools import count +from typing import Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +from torch import distributed as dist +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + GenerationConfig, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from colossalai.accelerator import get_accelerator +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig +from colossalai.inference.graph_runner import CUDAGraphRunner +from colossalai.inference.modeling.policy import model_policy_map +from colossalai.inference.sampler import search_tokens +from colossalai.inference.spec import Drafter, GlideInput +from colossalai.inference.struct import Sequence +from colossalai.inference.utils import get_model_size, has_index_file +from colossalai.interface import ModelWrapper +from colossalai.lazy import LazyInitContext +from colossalai.logging import get_dist_logger +from colossalai.shardformer.policies.base_policy import Policy + +from .base_engine import BaseEngine +from .request_handler import RequestHandler + +PP_AXIS, TP_AXIS = 0, 1 + +_supported_models = { + "LlamaForCausalLM": LlamaForCausalLM, + "BaichuanForCausalLM": AutoModelForCausalLM, +} + +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] + + +class LLMEngine(BaseEngine): + """ + InferenceEngine which manages the inference process.. 
+ + Args: + model_or_path (nn.Module or str): Path or nn.Module of this model. + tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use. + inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference. + verbose (bool): Determine whether or not to log the generation process. + model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided. + """ + + def __init__( + self, + model_or_path: nn.Module | str, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None, + inference_config: InferenceConfig = None, + verbose: bool = False, + model_policy: Policy | type[Policy] = None, + ) -> None: + self.inference_config = inference_config + self.dtype = inference_config.dtype + self.high_precision = inference_config.high_precision + + self.verbose = verbose + self.logger = get_dist_logger(__name__) + self.model_shard_infer_config = inference_config.to_model_shard_inference_config() + + self.init_model(model_or_path, model_policy, self.model_shard_infer_config) + + self.generation_config = inference_config.to_generation_config(self.model_config) + self.generation_config_dict = self.generation_config.to_dict() + + self.tokenizer = tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.request_handler = RequestHandler(self.inference_config, self.model_config) + self.k_cache, self.v_cache = self.request_handler.get_kvcache() + # DISCUSS maybe move this into batch info? + + self.counter = count() + + self.use_cuda_graph = self.inference_config.use_cuda_graph + if self.use_cuda_graph: + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + if verbose: + self.logger.info("Colossal AI CUDA Graph Capture on") + + self.capture_model(self.k_cache, self.v_cache) + + # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec` + self.use_spec_dec = self.inference_config.use_spec_dec + + self.drafter_model = None + self.drafter = None + self.use_glide = False + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + + self._verify_args() + + def init_model( + self, + model_or_path: Union[nn.Module, str], + model_policy: Union[Policy, Type[Policy]] = None, + model_shard_infer_config: ModelShardInferenceConfig = None, + ): + """ + Shard model or/and Load weight + + Args: + model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format. + model_policy (Policy): the policy to replace the model. + model_inference_config: the configuration for modeling initialization when inference. + model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference. + """ + pretrained_path = None + if isinstance(model_or_path, str): + import colossalai.interface.pretrained as pretrained_utils + + try: + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype) + arch = getattr(hf_config, "architectures")[0] + if arch in _supported_models.keys(): + if arch == "BaichuanForCausalLM": + self.logger.warning( + "Attention ! We use lazy init by default, which could be faster for model loading. 
For baichuan model, the output maybe have a slight difference with transformers" + ) + ctx = LazyInitContext(default_device="cuda") + with ctx: + model = _supported_models[arch].from_pretrained( + model_or_path, trust_remote_code=True, torch_dtype=self.dtype + ) + pretrained_path = pretrained_utils.get_pretrained_path(model) + else: + # TODO(char-1ee): if the model not supported, use transformers APIs to load and generate + raise ValueError(f"Model {arch} is not supported.") + + except Exception as e: + self.logger.error( + f"An exception occurred during loading model: {e}, model should be loaded by transformers\n" + ) + else: + model = model_or_path + + self.model_config = model.config + + torch.cuda.empty_cache() + init_gpu_memory = torch.cuda.mem_get_info()[0] + + self.device = get_accelerator().get_current_device() + if self.verbose: + self.logger.info(f"the device is {self.device}") + + model = model.to(self.dtype).eval() + + if self.verbose: + self.logger.info( + f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" + ) + + if model_policy is None: + prefix = "nopadding" if not self.inference_config.pad_input else "padding" + model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" + model_policy = model_policy_map.get(model_policy_key) + + if not isinstance(model_policy, Policy): + try: + model_policy = model_policy() + except Exception as e: + raise ValueError(f"Unable to instantiate model policy: {e}") + + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" + pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) + tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + + self.model = self._shardformer( + model, + model_policy, + model_shard_infer_config, + None, + tp_group=tp_group, + ) + + self.model = ModelWrapper(model).to(self.device) + + if self.verbose: + self.logger.info( + f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" + ) + + if pretrained_path: + from colossalai.inference.core.plugin import InferCheckpoint_io + + cpt_io = InferCheckpoint_io() + if_has_index_file, model_index_file = has_index_file(pretrained_path) + assert if_has_index_file, "the model path is invalid" + cpt_io.load_model(self.model, model_index_file) + + free_gpu_memory, _ = torch.cuda.mem_get_info() + peak_memory = init_gpu_memory - free_gpu_memory + if self.verbose: + self.logger.info( + f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" + ) + + @torch.inference_mode() + def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]): + assert self.use_cuda_graph, "please turn on the cuda graph" + + if self.verbose: + self.logger.info("Colossal AI CUDA Graph Capture begin") + + t_capture_begin = time.perf_counter() + + block_size = self.inference_config.block_size + head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads + + # Prepare dummy inputs. These will be reused for all batch sizes. 
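# (Illustrative numbers for the dummy-input sizing below, assuming e.g. block_size=16 and
#  max_context_len_to_capture=512: max_num_blocks = ceil(512 / 16) = 32 slots per dummy block
#  table row, and max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) = 8 * 32 = 256 dummy token ids.)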
+ max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + max_context_len_to_capture = self.inference_config.max_context_len_to_capture + max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size + input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda() + # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32) + self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE)) + self.graph_block_tables[0, :] = np.arange( + 0, max_num_blocks + ) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + output_tensor = torch.zeros( + (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device + ) + fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor + + max_num_seqs = self.inference_config.max_batch_size + batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs] + sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda() + # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len + sequence_lengths[0] = torch.tensor( + self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32 + ).cuda() + + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(batch_size_capture_list): + if self.verbose: + self.logger.info(f"batch size {batch_size} graph capturing") + + input_meta_data = InputMetaData( + block_tables=block_tables[:batch_size], + sequence_lengths=sequence_lengths[:batch_size], + fd_inter_tensor=fd_inter_tensor, + batch_size=batch_size, + is_prompts=False, + use_cuda_graph=True, + high_precision=False, + kv_seq_len=sequence_lengths[:batch_size].max().item(), + head_dim=head_dim, + dtype=self.dtype, + ) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_tokens_ids[:batch_size], + output_tensor[:batch_size], + input_meta_data, + k_caches=k_cache, + v_caches=v_cache, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + t_capture_end = time.perf_counter() + + if self.verbose: + self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s") + + def _verify_args(self) -> None: + """Verify the input args""" + if not isinstance(self.inference_config, InferenceConfig): + raise TypeError("Invalid type of inference config provided.") + if not isinstance(self.model, nn.Module): + raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}") + if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): + raise TypeError( + f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" + ) + if isinstance(self.model, ModelWrapper): + model = self.model.module + assert ( + model.__class__.__name__ in _supported_models.keys() + ), f"Model {self.model.__class__.__name__} is not supported." 
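# Example (sketch): how the graph runners captured above are used later at decode time. This
# condenses the selection logic that appears in prepare_input()/step() elsewhere in this
# engine; the helper name is hypothetical:
#
#   def _select_executable(self, batch):
#       # captured graphs only cover decode steps whose batch size was captured above;
#       # prefill and uncaptured batch sizes fall back to eager execution
#       if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners:
#           return self.graph_runners[batch.current_batch_size]
#       return self.model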
+ + def enable_spec_dec( + self, + drafter_model: nn.Module = None, + n_spec_tokens: int = None, + use_glide_drafter: bool = False, + ) -> None: + """Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations. + + Args: + drafter_model (nn.Module): The drafter model (small model) used to speculate tokens. + If provided, the previous drafter and drafter model, if exist, will be overwritten. + n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying. + If not provided, `max_n_spec_tokens` in InferenceConfig will be used. + use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False. + If True, the drafter model will be replaced by a glide model. + + ```python + ... + engine = InferenceEngine(model, tokenizer, inference_config) + + engine.enable_spec_dec(drafter_model, n_spec_tokens=5) + engine.generate(...) # Speculative Decoding + + engine.disable_spec_dec() + engine.generate(...) # Normal generation + + engine.enable_spec_dec() + engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens + engine.clear_spec_dec() + ``` + """ + + if drafter_model is None and self.drafter is None: + raise ValueError("Drafter not initialized. Please provide a Drafter Model") + if n_spec_tokens is not None: + assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens + self.n_spec_tokens = n_spec_tokens + if drafter_model is not None: + assert isinstance(drafter_model, nn.Module) + # overwrite the drafter, if exists + self.clear_spec_dec() + self.drafter_model = drafter_model + self.drafter = Drafter( + self.drafter_model, + self.tokenizer, + device=self.device, + dtype=self.dtype, + ) + + # check if the provided drafter model is compatible with GLIDE structure + # when `use_glide_drafter` is set to True + if ( + use_glide_drafter + and hasattr(drafter_model, "model") + and hasattr(drafter_model.model, "layers") + and hasattr(drafter_model.model.layers[0], "cross_attn") + ): + self.use_glide = use_glide_drafter + elif use_glide_drafter: + self.logger.warning( + f"`use_glide_drafter` is provided as {use_glide_drafter}, " + f"but the provided drafter model is not compatible with GLIDE structure." + f"Falling back to use the default drafter model (non-GLIDE)." + ) + self.request_handler.set_spec_dec_mode(self.n_spec_tokens) + # using speculative decoding for subsequent generations + self.use_spec_dec = True + + def disable_spec_dec(self) -> None: + """Disable using speculative decoding for subsequent generations.""" + self.request_handler.unset_spec_dec_mode() + # set back to the maximum number of tokens to speculate + self.n_spec_tokens = self.inference_config.max_n_spec_tokens + self.use_glide = False + self.use_spec_dec = False + + def clear_spec_dec(self) -> None: + """Clear relatable structures of speculative decoding, if exist.""" + if self.use_spec_dec: + self.disable_spec_dec() + if self.drafter_model or self.drafter: + self.drafter_model = None + self.drafter = None + torch.cuda.empty_cache() + self.use_glide = False + self.use_spec_dec = False + + def steps_spec_dec(self) -> List[Sequence]: + """ + Run Speculative Decoding steps. This is like retrieving a single batch and launch inference + with many steps of speculating by a drafter model as well as verifying by a main model. + + Returns: + List[Sequence]: finished sequences generated by one step. 
+ """ + batch = self.request_handler.schedule() # prefill batch + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model + + # 1. Prefill small model (Drafter) - fill past kv cache for drafter model + # NOTE For glide drafter models, we won't actually apply glide during prefill stage + drafter_out = self.drafter.speculate(input_token_ids, 1, None) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + + # 2. Prefill main model (Verifier) - fill past kv cache for main model + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) + # append new inputs to the batch, temporarily + batch.append_batch_tokens(next_tokens) + self.request_handler.allocate_batch_spec_dec(batch, 1) + already_allocated_kv_len = batch.seq_lengths[0].item() + input_token_ids = batch.get_1D_inputs_spec_dec(1) + + finished_sequences = self.request_handler.update() + + while True: + # HACK Retrieve the running batch + # Using RequestHandler.schedule here will re-allocate same kv cache for the batch + batch = self.request_handler.running_bb # running batch + assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." + + # 3. Decoding - Drafter model speculates `n` tokens + glide_input = None + if self.use_glide: + glide_input = GlideInput( + batch.get_block_table_tensor(), + self.k_cache[-1], # use kv cahces of the last layer + self.v_cache[-1], + batch.get_sequence_lengths(), + n_spec_tokens=self.n_spec_tokens, + ) + + drafter_out = self.drafter.speculate( + input_token_ids, + self.n_spec_tokens, + drafter_past_key_values, + glide_input=glide_input, + ) + next_token_ids_spec = drafter_out.next_tokens + drafter_past_key_values = drafter_out.past_key_values + drafter_spec_length = drafter_out.speculated_length + + for next_token_id_spec in next_token_ids_spec: + self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0)) + cur_length = batch.seq_lengths[0].item() + if already_allocated_kv_len < cur_length: + self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len) + already_allocated_kv_len = cur_length + + # 4. Decoding - Main model verifies `n` tokens in parallel + if drafter_spec_length < batch.num_tokens_to_verify: + batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length) + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + + next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) + + # 5. 
Compare and process the results + diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) + n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item() + + # revoke appended tokens for each Sequence in the current batch + batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens + + # append the last correct token generated by the main model + self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0)) + + # trim past key values of the drafter model + drafter_past_key_values = Drafter.trim_kv_cache( + drafter_past_key_values, drafter_spec_length - n_matches - 1 + ) + + # prepare inputs for the next round of speculation + n = 1 if n_matches < drafter_spec_length else 2 + input_token_ids = batch.get_1D_inputs_spec_dec(n) + + self.request_handler.update_batch_finished(batch, generation_config=self.generation_config) + finished_sequences = self.request_handler.update() + if len(finished_sequences) > 0: + break + + # Reset back the number of speculated tokens of the batch, + # this is used to handle the last round of speculation, in which case the number of speculated tokens + # by the drafter is less than the number of speculated tokens set to the engine. + batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens) + + return finished_sequences + + def generate( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + return_token_ids: bool = False, + generation_config: Optional[GenerationConfig] = None, + ) -> Union[List[str], Tuple[List[str], List[List[int]]]]: + """ + Executing the inference step. + + Args: + request_ids (List[int], optional): The request ID. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None. + return_token_ids (bool, optional): Whether to return output token ids. Defaults to False. + generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None. + + Returns: + Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation. + """ + + gen_config_dict = generation_config.to_dict() if generation_config is not None else {} + prompts = [prompts] if isinstance(prompts, str) else prompts + request_ids = [request_ids] if isinstance(request_ids, int) else request_ids + + with torch.inference_mode(): + if prompts is not None or prompts_token_ids is not None: + self.add_request( + request_ids=request_ids, + prompts=prompts, + prompts_token_ids=prompts_token_ids, + **gen_config_dict, + ) + + output_seqs_list = [] + total_tokens_list = [] + + # intuition: If user provide a generation config, we should replace the existing one. + if generation_config is not None: + self.generation_config = generation_config + self.generation_config_dict = gen_config_dict + + if self.use_spec_dec: + assert self.drafter is not None, "Drafter Model is not initialized." 
+ while self.request_handler.check_unfinished_reqs(): + output_seqs_list += self.steps_spec_dec() + else: + while self.request_handler.check_unfinished_reqs(): + output_seqs_list += self.step() + + output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id)) + + for seq in output_seqs_list: + total_tokens_list.append(seq.input_token_id + seq.output_token_id) + + output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True) + + if return_token_ids: + output_tokens_list = [seq.output_token_id for seq in output_seqs_list] + return output_str, output_tokens_list + else: + return output_str + + @property + def has_prompt_template(self) -> bool: + """ """ + return self.inference_config.prompt_template is not None + + def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]: + """ + This method will format the input prompt according to the prompt template given to the InferenceConfig. + """ + assert ( + self.has_prompt_template + ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig." + + if isinstance(prompts, (list, tuple)): + return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts] + elif isinstance(prompts, str): + return self.inference_config.prompt_template.format(input_text=prompts) + else: + raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.") + + def add_request( + self, + request_ids: Union[List[int], int] = None, + prompts: Union[List[str], str] = None, + prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None, + **kwargs, + ) -> None: + """ + Add requests. + + Args: + request_ids (List[int], optional): The request ID. Defaults to None. + prompts (Union[List[str], optional): Input prompts. Defaults to None. + prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None. + """ + + # apply the prompt template to the input prompts + + if self.has_prompt_template and prompts is not None: + prompts = self.format_prompt(prompts) + + block_size = self.inference_config.block_size + + if request_ids is not None and not isinstance(request_ids, list): + request_ids = [request_ids] + + if prompts is not None and not isinstance(prompts, list): + prompts = [prompts] + + if prompts_token_ids is None: + assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided." + prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[ + "input_ids" + ] + + # list of torch Tensor + if isinstance(prompts_token_ids, list): + if isinstance(prompts_token_ids[0], torch.Tensor): + prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids] + elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): + prompts_token_ids = prompts_token_ids.tolist() + else: + raise TypeError( + f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}." + ) + + assert ( + len(prompts_token_ids[0]) <= self.inference_config.max_input_len + ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}." 
+ + prompts_num = len(prompts_token_ids) + + for i in range(prompts_num): + if request_ids: + assert isinstance( + request_ids[0], int + ), f"The request_id type must be int, but got {type(request_ids[0])}" + assert len(request_ids) == prompts_num + request_id = request_ids[i] + else: + request_id = next(self.counter) + if prompts == None: + prompt = None + else: + prompt = prompts[i] + + max_length = kwargs.get("max_length", None) + max_new_tokens = kwargs.get("max_new_tokens", None) + if max_length is None and max_new_tokens is None: + max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len + elif max_length is not None: + max_new_tokens = max_length - len(prompts_token_ids[i]) + + if not self.inference_config.enable_streamingllm: + assert ( + self.inference_config.max_output_len >= max_new_tokens + ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." + + sequence = Sequence( + request_id, + prompt, + prompts_token_ids[i], + block_size, + None, + self.tokenizer.eos_token_id, + self.tokenizer.pad_token_id, + max_output_len=max_new_tokens, + ignore_eos=self.inference_config.ignore_eos, + ) + self.request_handler.add_sequence(sequence) + + def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]: + input_ids = batch.get_1D_inputs() + sequence_lengths = batch.get_sequence_lengths() + + if batch.is_prompts: + n_tokens = sequence_lengths.sum().item() + else: + n_tokens = batch.current_batch_size + if batch.use_spec_dec: + n_tokens = batch.num_tokens_to_verify + 1 + assert n_tokens == input_ids.size(0) + n_tokens = n_tokens * batch.current_batch_size + output_tensor = torch.zeros( + (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device + ) + + batch_token_ids = None + if ( + self.generation_config.repetition_penalty != 1.0 + or self.generation_config.no_repeat_ngram_size > 0 + or self.generation_config.forced_eos_token_id is not None + ): + batch_token_ids = batch.batch_token_ids + + # only when we have the graph for specific decoding batch size can we use the cuda graph for inference + use_cuda_graph = False + if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): + use_cuda_graph = True + + input_meta_data = InputMetaData( + block_tables=batch.get_block_table_tensor(), + sequence_lengths=sequence_lengths, + fd_inter_tensor=batch.fd_inter_tensor, + batch_size=batch.current_batch_size, + is_prompts=batch.is_prompts, + use_cuda_kernel=self.inference_config.use_cuda_kernel, + use_cuda_graph=use_cuda_graph, + high_precision=self.high_precision, + kv_seq_len=sequence_lengths.max().item(), + head_dim=batch.head_dim, + dtype=batch.dtype, + use_spec_dec=batch.use_spec_dec, + num_tokens_to_verify=batch.num_tokens_to_verify, + batch_token_ids=batch_token_ids, + ) + + return input_ids, output_tensor, input_meta_data + + def step(self) -> List[str]: + """ + In each step, do the follows: + 1. Run RequestHandler.schedule() and get the batch used for inference. + 2. Get the input, inputinfo and output placeholder from the batchbucket + 3. Run model to generate the next token + 4. Update waiting list and running list in RequestHandler and get finished sequences. + 5. Decode and return finished sequences. + + Returns: + List[str]: Decoded finished sequences generated by one step. 
+ """ + + batch = self.request_handler.schedule() + + input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) + + if input_meta_data.use_cuda_graph: + model_executable = self.graph_runners[input_meta_data.batch_size] + else: + model_executable = self.model + + # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. + logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) + if self.inference_config.pad_input: + logits = logits[:, -1, :] + + if self.inference_config.enable_streamingllm: + updated_block_ids = batch.streamingllm_update_batch( + self.inference_config.start_token_size, self.inference_config.generated_token_size + ) + self.request_handler.streamingllm_free_block_tables(updated_block_ids) + + next_tokens = search_tokens( + self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids + ) + self.request_handler.append_next_tokens(next_tokens) + finished_sequences = self.request_handler.update() + + return finished_sequences diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 512eaea71..393347c31 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket from colossalai.inference.config import InferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager -from colossalai.inference.struct import RequestStatus, Sequence +from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence from colossalai.logging import get_dist_logger logger = get_dist_logger(__name__) @@ -98,7 +98,46 @@ class RunningList: self._decoding[seq_id] = self._prefill.pop(seq_id) -class RequestHandler: +class NaiveRequestHandler: + def __init__(self) -> None: + self.running_list: List[DiffusionSequence] = [] + self.waiting_list: List[str] = [] + + def _has_waiting(self) -> bool: + return any(lst for lst in self.waiting_list) + + def _has_running(self) -> bool: + return any(lst for lst in self.running_list) + + def check_unfinished_reqs(self): + return self._has_waiting() or self._has_running() + + def add_sequence(self, seq: DiffusionSequence): + """ + Add the request to waiting list. + """ + assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists." + self.waiting_list.append(seq) + + def _find_sequence(self, request_id: int) -> DiffusionSequence: + """ + Find the request by request_id. + """ + for lst in enumerate(self.waiting_list + self.running_list): + for seq in lst: + if seq.request_id == request_id: + return seq + return None + + def schedule(self): + ret = None + if self._has_waiting: + ret = self.waiting_list[0] + self.waiting_list = self.waiting_list[1:] + return ret + + +class RequestHandler(NaiveRequestHandler): """ RequestHandler is the core for handling existing requests and updating current batch. During generation process, we call schedule function each iteration to update current batch. 
@@ -176,12 +215,12 @@ class RequestHandler: generated_token_size=inference_config.generated_token_size, ) + def _has_running(self) -> bool: + return not self.running_bb.is_empty() + def _init_cache(self, model_config): self.cache_manager = KVCacheManager(self.inference_config, model_config) - def _has_waiting(self) -> bool: - return any(lst for lst in self.waiting_list) - def get_kvcache(self): return self.cache_manager.get_kv_cache() @@ -318,7 +357,7 @@ class RequestHandler: if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens: seq.mark_finished() - def check_unfinished_seqs(self) -> bool: + def check_unfinished_reqs(self) -> bool: return self._has_waiting() or not self.running_list.is_empty() def total_requests_in_batch_bucket(self) -> int: diff --git a/colossalai/inference/modeling/models/diffusion.py b/colossalai/inference/modeling/models/diffusion.py new file mode 100644 index 000000000..9dc90733d --- /dev/null +++ b/colossalai/inference/modeling/models/diffusion.py @@ -0,0 +1,54 @@ +import inspect +import types + +import torch +from torch import nn + + +class DiffusionPipe(nn.Module): + """ + This Class convert a class of `DiffusionPipeline` into `nn.Module` and reserve most of origin attr,function and property. + """ + + def __init__(self, source_obj) -> None: + super(DiffusionPipe, self).__init__() + + for k, v in source_obj.__dict__.items(): + if isinstance(v, nn.Module): + self.add_module(k, v) + else: + setattr(self, k, v) + + skip_list = ["_execution_device", "to", "device"] # this + + for name, member in inspect.getmembers(source_obj.__class__): + if name in skip_list: + continue + if not name.startswith("__") and not name.endswith("__"): + if isinstance(member, property): + setattr(self.__class__, name, member) + elif inspect.isfunction(member) or inspect.ismethod(member): + bound_method = types.MethodType(member, self) + setattr(self, name, bound_method) + elif not callable(member) and not isinstance(member, property): + setattr(self, name, member) + elif name == "__call__": + bound_method = types.MethodType(member, self) + setattr(self, "_forward", bound_method) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. 
+ """ + # return self.device + return torch.device("cuda") + + @property + def device(self): + next(self.parameters()).device + + def forward(self, *args, **kwargs): + return self._forward(*args, **kwargs) diff --git a/colossalai/inference/modeling/models/pixart_alpha.py b/colossalai/inference/modeling/models/pixart_alpha.py new file mode 100644 index 000000000..d5774946e --- /dev/null +++ b/colossalai/inference/modeling/models/pixart_alpha.py @@ -0,0 +1,220 @@ +# Code adapted from: +# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py + +from typing import Callable, List, Optional, Union + +import PIL.Image +import torch +from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from colossalai.logging import get_dist_logger + +from .diffusion import DiffusionPipe + +logger = get_dist_logger(__name__) + + +@torch.no_grad() +def pixart_alpha_forward( + self: DiffusionPipe, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, +) -> PIL.Image: + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + # For DMD one step sampling: https://arxiv.org/abs/2311.18828 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + output_type = "pil" # TODO(@lry89757) temporarily image, please support more return output + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + # self.maybe_free_model_hooks() + + return image diff --git a/colossalai/inference/modeling/models/stablediffusion3.py b/colossalai/inference/modeling/models/stablediffusion3.py new file mode 100644 index 000000000..d1c63a6dc --- /dev/null +++ b/colossalai/inference/modeling/models/stablediffusion3.py @@ -0,0 +1,178 @@ +# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps + +from .diffusion import DiffusionPipe + + +# TODO(@lry89757) temporarily image, please support more return output +@torch.no_grad() +def sd3_forward( + self: DiffusionPipe, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + 
num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], +): + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + return image diff --git a/colossalai/inference/modeling/policy/__init__.py b/colossalai/inference/modeling/policy/__init__.py index fa0395590..02ffadd9f 100644 --- a/colossalai/inference/modeling/policy/__init__.py +++ b/colossalai/inference/modeling/policy/__init__.py @@ -1,16 +1,22 @@ from .glide_llama import GlideLlamaModelPolicy from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy from .nopadding_llama import NoPaddingLlamaModelInferPolicy +from .pixart_alpha import PixArtAlphaInferPolicy +from .stablediffusion3 import StableDiffusion3InferPolicy model_policy_map = { "nopadding_llama": NoPaddingLlamaModelInferPolicy, "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy, "glide_llama": GlideLlamaModelPolicy, + "StableDiffusion3Pipeline": StableDiffusion3InferPolicy, + "PixArtAlphaPipeline": PixArtAlphaInferPolicy, } __all__ = [ "NoPaddingLlamaModelInferPolicy", "NoPaddingBaichuanModelInferPolicy", "GlideLlamaModelPolicy", + "StableDiffusion3InferPolicy", + "PixArtAlphaInferPolicy", "model_polic_map", ] diff --git a/colossalai/inference/modeling/policy/pixart_alpha.py b/colossalai/inference/modeling/policy/pixart_alpha.py new file mode 100644 index 000000000..356056ba7 --- /dev/null +++ 
b/colossalai/inference/modeling/policy/pixart_alpha.py @@ -0,0 +1,34 @@ +from torch import nn + +from colossalai.inference.config import RPC_PARAM +from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward +from colossalai.shardformer.policies.base_policy import Policy + + +class PixArtAlphaInferPolicy(Policy, RPC_PARAM): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = {} + self.append_or_create_method_replacement( + description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe + ) + return policy + + def preprocess(self) -> nn.Module: + return self.model + + def postprocess(self): + return self.model + + def config_sanity_check(self): + pass + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "PixArtAlphaInferPolicy": + return PixArtAlphaInferPolicy() diff --git a/colossalai/inference/modeling/policy/stablediffusion3.py b/colossalai/inference/modeling/policy/stablediffusion3.py new file mode 100644 index 000000000..c9877f7dc --- /dev/null +++ b/colossalai/inference/modeling/policy/stablediffusion3.py @@ -0,0 +1,34 @@ +from torch import nn + +from colossalai.inference.config import RPC_PARAM +from colossalai.inference.modeling.models.diffusion import DiffusionPipe +from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward +from colossalai.shardformer.policies.base_policy import Policy + + +class StableDiffusion3InferPolicy(Policy, RPC_PARAM): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = {} + self.append_or_create_method_replacement( + description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe + ) + return policy + + def preprocess(self) -> nn.Module: + return self.model + + def postprocess(self): + return self.model + + def config_sanity_check(self): + pass + + def to_rpc_param(self) -> str: + return __class__.__name__ + + @staticmethod + def from_rpc_param() -> "StableDiffusion3InferPolicy": + return StableDiffusion3InferPolicy() diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 1a3094a27..65d284296 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -2,6 +2,7 @@ import enum from dataclasses import dataclass from typing import Any, List +from colossalai.inference.config import DiffusionGenerationConfig from colossalai.logging import get_dist_logger logger = get_dist_logger(__name__) @@ -46,6 +47,17 @@ class RequestStatus(enum.Enum): return status == RequestStatus.WAITING +@dataclass +class DiffusionSequence: + """ + parameters for diffusion + """ + + request_id: int + prompt: str + generation_config: DiffusionGenerationConfig + + @dataclass class Sequence: """Store information of input sequence. 
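Since `DiffusionSequence` carries a `DiffusionGenerationConfig` rather than raw kwargs, it may help to note how the config round-trips: only fields that were explicitly set survive `to_dict()`, so pipeline defaults are not overridden by `None`s. A small sketch (values arbitrary):

```python
from colossalai.inference.config import DiffusionGenerationConfig

cfg = DiffusionGenerationConfig.from_kwargs(num_inference_steps=20, guidance_scale=4.5)

# Unset fields stay None and are dropped, so only these two kwargs reach the diffusers pipeline call.
assert cfg.to_dict() == {"num_inference_steps": 20, "guidance_scale": 4.5}
assert "negative_prompt" not in cfg.to_dict()
```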
diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 332e84d37..f2a0fc037 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -5,10 +5,12 @@ Utils for model inference import math import os import re +from enum import Enum from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch +from diffusers import DiffusionPipeline from torch import nn from colossalai.logging import get_dist_logger @@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool: except ImportError: logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") return False + + +class ModelType(Enum): + DIFFUSION_MODEL = "Diffusion Model" + LLM = "Large Language Model (LLM)" + UNKNOWN = "Unknown Model Type" + + +def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]): + if isinstance(model_or_path, DiffusionPipeline): + return ModelType.DIFFUSION_MODEL + elif isinstance(model_or_path, nn.Module): + return ModelType.LLM + elif isinstance(model_or_path, str): + try: + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) + return ModelType.LLM + except: + """ + model type is not `ModelType.LLM` + """ + + try: + from diffusers import DiffusionPipeline + + DiffusionPipeline.load_config(model_or_path) + return ModelType.DIFFUSION_MODEL + except: + """ + model type is not `ModelType.DIFFUSION_MODEL` + """ + else: + return ModelType.UNKNOWN diff --git a/examples/inference/stable_diffusion/sd3_generation.py b/examples/inference/stable_diffusion/sd3_generation.py new file mode 100644 index 000000000..fe989eed7 --- /dev/null +++ b/examples/inference/stable_diffusion/sd3_generation.py @@ -0,0 +1,75 @@ +import argparse + +from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline +from torch import bfloat16, float16, float32 + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy +from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy + +# For Stable Diffusion 3, we'll use the following configuration +MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0] +POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0] + +TORCH_DTYPE_MAP = { + "fp16": float16, + "fp32": float32, + "bf16": bfloat16, +} + + +def infer(args): + # ============================== + # Launch colossalai, setup distributed environment + # ============================== + colossalai.launch_from_torch() + coordinator = DistCoordinator() + + # ============================== + # Load model and tokenizer + # ============================== + model_path_or_name = args.model + model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None)) + + # ============================== + # Initialize InferenceEngine + # ============================== + coordinator.print_on_master(f"Initializing Inference Engine...") + inference_config = InferenceConfig( + dtype=args.dtype, + max_batch_size=args.max_batch_size, + tp_size=args.tp_size, + use_cuda_kernel=args.use_cuda_kernel, + ) + engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True) + + # 
============================== + # Generation + # ============================== + coordinator.print_on_master(f"Generating...") + out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0] + out.save("cat.jpg") + coordinator.print_on_master(out) + + +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH +# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 + + +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-m", "--model", type=str, help="Path to the model or model name") + parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size") + parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt") + parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size") + parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"]) + parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default") + args = parser.parse_args() + + infer(args) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 27bbc3769..b54d1cf91 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -23,3 +23,4 @@ rpyc==6.0.0 fastapi uvicorn==0.29.0 galore_torch +diffusers==0.29.0
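The example script is wired for Stable Diffusion 3 by default; switching the index on `MODEL_CLS`/`POLICY_CLS` selects PixArt-Alpha instead. A possible variant is sketched below; the Hugging Face model id is an assumption, so substitute your own checkpoint path if it differs.

```python
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline

from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy

# Same selection mechanism as in sd3_generation.py, but picking the PixArt-Alpha pair:
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][1]
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][1]

# Launch mirrors the SD3 command (model id below is an assumption):
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py \
#     -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1 -d fp16
```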