mirror of https://github.com/hpcaitech/ColossalAI
[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
parent 8ec24b6a4d
commit cba20525a8
@@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers.generation import GenerationConfig
@@ -396,3 +396,49 @@ class ModelShardInferenceConfig:
    use_cuda_kernel: bool = False
    use_spec_dec: bool = False
    use_flash_attn: bool = False


@dataclass
class DiffusionGenerationConfig:
    """
    Parameters for the diffusion model forward pass.
    """

    prompt_2: Optional[Union[str, List[str]]] = None
    prompt_3: Optional[Union[str, List[str]]] = None
    height: Optional[int] = None
    width: Optional[int] = None
    num_inference_steps: int = None
    timesteps: List[int] = None
    guidance_scale: float = None
    negative_prompt: Optional[Union[str, List[str]]] = (
        None  # NOTE(@lry89757) defaults to "" in PixArt and to None in SD3
    )
    negative_prompt_2: Optional[Union[str, List[str]]] = None
    negative_prompt_3: Optional[Union[str, List[str]]] = None
    num_images_per_prompt: Optional[int] = None
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
    latents: Optional[torch.FloatTensor] = None
    prompt_embeds: Optional[torch.FloatTensor] = None
    negative_prompt_embeds: Optional[torch.FloatTensor] = None
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
    output_type: Optional[str] = None  # "pil"
    return_dict: bool = None
    joint_attention_kwargs: Optional[Dict[str, Any]] = None
    clip_skip: Optional[int] = None
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
    callback_on_step_end_tensor_inputs: List[str] = None

    def to_dict(self) -> Dict[str, Any]:
        # NOTE(@lry89757) Only include fields that are not None (i.e., not left at their default)
        result = {}
        for field in fields(self):
            value = getattr(self, field.name)
            if value is not None:
                result[field.name] = value
        return result

    @classmethod
    def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
        return cls(**kwargs)
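Review note (not part of the diff): a minimal usage sketch of `DiffusionGenerationConfig`, assuming it is importable from `colossalai.inference.config` as the diffusion engine below does; the field values are illustrative. Fields left at `None` are dropped by `to_dict()`, so the wrapped diffusers pipeline falls back to its own defaults for them.

```python
from colossalai.inference.config import DiffusionGenerationConfig

# Only explicitly set fields survive to_dict(); unset fields stay None and are omitted.
gen_config = DiffusionGenerationConfig.from_kwargs(
    height=1024,
    width=1024,
    num_inference_steps=20,
    guidance_scale=7.0,
)
print(gen_config.to_dict())
# -> {'height': 1024, 'width': 1024, 'num_inference_steps': 20, 'guidance_scale': 7.0}
```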
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy


class BaseEngine(ABC):
    @abstractmethod
    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        pass

    @abstractmethod
    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        """
        Initialize the model for the engine.
        """

    @abstractmethod
    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        """
        Generate output for incoming requests.
        """

    @abstractmethod
    def add_request(self, prompts, request_ids=None, **kwargs):
        """
        Add new requests to the engine.
        """

    @abstractmethod
    def step(self):
        """
        Perform one forward step.
        """

    @abstractmethod
    def _verify_args(self):
        """
        Verify the parameters and members of the class.
        """

    @torch.inference_mode()
    def capture_model(self):
        """
        Use CUDA graphs to capture the model.
        """
        raise NotImplementedError("This method should be implemented by subclasses")

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
        **kwargs,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): The model to be sharded.
            model_policy (Policy): The policy used by ShardFormer to shard the model, determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the TP process group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """

        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model
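Review note (not part of the diff): a toy sketch of how a concrete engine is expected to fill in this abstract interface. The `colossalai.inference.core.base_engine` module path is inferred from this commit's imports; the echo behaviour is purely illustrative, not how LLMEngine/DiffusionEngine actually work.

```python
import torch.nn as nn

from colossalai.inference.core.base_engine import BaseEngine  # module path assumed


class EchoEngine(BaseEngine):
    """Illustrative engine that simply echoes prompts back, one batch per step()."""

    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        self.init_model(model_or_path)
        self.waiting = []

    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        self.model = nn.Identity()  # stand-in for a real (sharded) model

    def add_request(self, prompts, request_ids=None, **kwargs):
        self.waiting.extend(prompts if isinstance(prompts, list) else [prompts])

    def step(self):
        finished, self.waiting = self.waiting, []
        return finished

    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        self.add_request(prompts)
        outputs = []
        while self.waiting:
            outputs += self.step()
        return outputs

    def _verify_args(self):
        assert isinstance(self.model, nn.Module)
```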
@@ -0,0 +1,200 @@
from itertools import count
from typing import List, Tuple, Type, Union

import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from torch import distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .base_engine import BaseEngine
from .request_handler import NaiveRequestHandler

PP_AXIS, TP_AXIS = 0, 1


class DiffusionEngine(BaseEngine):
    def __init__(
        self,
        model_or_path: DiffusionPipeline | str,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Policy | type[Policy] = None,
    ) -> None:
        self.inference_config = inference_config
        self.dtype = inference_config.dtype
        self.high_precision = inference_config.high_precision

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

        self.model_type = get_model_type(model_or_path=model_or_path)

        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.request_handler = NaiveRequestHandler()

        self.counter = count()

        self._verify_args()

    def _verify_args(self) -> None:
        assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"

    def init_model(
        self,
        model_or_path: Union[str, nn.Module, DiffusionPipeline],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        """
        Shard the model and/or load its weights.

        Args:
            model_or_path (Union[str, nn.Module, DiffusionPipeline]): Path to the checkpoint or a diffusers pipeline/model instance.
            model_policy (Policy): the policy used to replace the model.
            model_shard_infer_config (ModelShardInferenceConfig): the configuration used to initialize modules for inference.
        """
        if isinstance(model_or_path, str):
            model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
            policy_map_key = model.__class__.__name__
            model = DiffusionPipe(model)
        elif isinstance(model_or_path, DiffusionPipeline):
            policy_map_key = model_or_path.__class__.__name__
            model = DiffusionPipe(model_or_path)
        else:
            self.logger.error("model_or_path supports only str or DiffusionPipeline currently!")

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        if self.verbose:
            self.logger.info(f"the device is {self.device}")

        if self.verbose:
            self.logger.info(
                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            model_policy = model_policy_map.get(policy_map_key)

        if not isinstance(model_policy, Policy):
            try:
                model_policy = model_policy()
            except Exception as e:
                raise ValueError(f"Unable to instantiate model policy: {e}")

        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            model_shard_infer_config,
            None,
            tp_group=tp_group,
        )

        self.model = model.to(self.device)

        if self.verbose:
            self.logger.info(
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )

    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        generation_config: DiffusionGenerationConfig = None,
        **kwargs,
    ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
        """Generate outputs (e.g. images) for the given prompts with the diffusion pipeline."""
        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        prompts = [prompts] if isinstance(prompts, str) else prompts
        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

        with torch.inference_mode():
            if prompts is not None:
                self.add_request(
                    request_ids=request_ids,
                    prompts=prompts,
                    **gen_config_dict,
                    **kwargs,
                )

            output_reqs_list = []

            # If the user provides a generation config, replace the existing one.
            if generation_config is not None:
                self.generation_config = generation_config
                self.generation_config_dict = gen_config_dict

            while self.request_handler.check_unfinished_reqs():
                output_reqs_list += self.step()

            return output_reqs_list

    def add_request(
        self,
        prompts: Union[List[str], str],
        request_ids: Union[List[int], int] = None,
        **kwargs,
    ):
        if request_ids is not None and not isinstance(request_ids, list):
            request_ids = [request_ids]

        if not isinstance(prompts, list):
            prompts = [prompts]

        generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
        prompts_num = len(prompts)
        for i in range(prompts_num):
            if request_ids:
                assert isinstance(
                    request_ids[0], int
                ), f"The request_id type must be int, but got {type(request_ids[0])}"
                assert len(request_ids) == prompts_num
                request_id = request_ids[i]
            else:
                request_id = next(self.counter)

            seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)

            self.request_handler.add_sequence(seq)

    def step(self) -> List[PIL.Image.Image]:
        """
        In each step, do the following:
            1. Run RequestHandler.schedule() to get the batch used for inference.
            2. Run the forward pass to get a list of images.
        Returns:
            List[PIL.Image.Image]: Images generated in one step.
        """

        input = self.request_handler.schedule()
        ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
        return ret
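Review note (not part of the diff): a rough end-to-end usage sketch of the new diffusion path. The wrapper import path, the bare `InferenceConfig()` defaults, the launch call, and the checkpoint id are assumptions; `generate()` is expected to return PIL images per the `step()` annotation above.

```python
import torch
from diffusers import DiffusionPipeline

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # module path assumed

colossalai.launch_from_torch()  # assumed single-process launch via torchrun

# Any supported diffusers pipeline should work; PixArt-alpha is used as an example here.
pipe = DiffusionPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)

engine = InferenceEngine(pipe, inference_config=InferenceConfig())
images = engine.generate(prompts="an astronaut riding a horse on the moon", num_inference_steps=20)
images[0].save("astronaut.png")  # assuming PIL.Image outputs, per DiffusionEngine.step()
```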
@@ -1,57 +1,24 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import List, Tuple, Type, Union

import numpy as np
import torch
import PIL.Image
import torch.nn as nn
from torch import distributed as dist
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from diffusers import DiffusionPipeline
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.utils import ModelType, get_model_type
from colossalai.shardformer.policies.base_policy import Policy

from .request_handler import RequestHandler

__all__ = ["InferenceEngine"]

PP_AXIS, TP_AXIS = 0, 1

_supported_models = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "BaichuanForCausalLM": AutoModelForCausalLM,
}

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


class InferenceEngine:
    """
    InferenceEngine which manages the inference process.

    Args:
        model_or_path (nn.Module or str): Path or nn.Module of this model.
        model_or_path (nn.Module or DiffusionPipeline or str): Path, nn.Module, or DiffusionPipeline of this model.
        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
@@ -60,567 +27,68 @@ class InferenceEngine:

    def __init__(
        self,
        model_or_path: Union[nn.Module, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        inference_config: InferenceConfig,
        model_or_path: Union[nn.Module, str, DiffusionPipeline],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Union[Policy, Type[Policy]] = None,
    ) -> None:
        self.inference_config = inference_config
        self.dtype = inference_config.dtype
        self.high_precision = inference_config.high_precision
        self.__dict__["_initialized"] = False  # use __dict__ directly to avoid calling __setattr__
        self.model_type = get_model_type(model_or_path=model_or_path)
        self.engine = None
        if self.model_type == ModelType.LLM:
            from .llm_engine import LLMEngine

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
            self.engine = LLMEngine(
                model_or_path=model_or_path,
                tokenizer=tokenizer,
                inference_config=inference_config,
                verbose=verbose,
                model_policy=model_policy,
            )
        elif self.model_type == ModelType.DIFFUSION_MODEL:
            from .diffusion_engine import DiffusionEngine

        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.generation_config = inference_config.to_generation_config(self.model_config)
        self.generation_config_dict = self.generation_config.to_dict()

        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
        # DISCUSS maybe move this into batch info?

        self.counter = count()

        self.use_cuda_graph = self.inference_config.use_cuda_graph
        if self.use_cuda_graph:
            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
            self.graph_memory_pool = None  # Set during graph capture.
            if verbose:
                self.logger.info("Colossal AI CUDA Graph Capture on")

            self.capture_model(self.k_cache, self.v_cache)

        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
        self.use_spec_dec = self.inference_config.use_spec_dec

        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
            self.engine = DiffusionEngine(
                model_or_path=model_or_path,
                inference_config=inference_config,
                verbose=verbose,
                model_policy=model_policy,
            )
        elif self.model_type == ModelType.UNKNOWN:
            self.logger.error("Model type must be either diffusion or LLM!")

        self._initialized = True
        self._verify_args()
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard model or/and Load weight
|
||||
|
||||
Args:
|
||||
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config: the configuration for modeling initialization when inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
pretrained_path = None
|
||||
if isinstance(model_or_path, str):
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
if arch in _supported_models.keys():
|
||||
if arch is "BaichuanForCausalLM":
|
||||
self.logger.warning(
|
||||
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
|
||||
)
|
||||
ctx = LazyInitContext(default_device="cuda")
|
||||
with ctx:
|
||||
model = _supported_models[arch].from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
||||
else:
|
||||
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
||||
)
|
||||
else:
|
||||
model = model_or_path
|
||||
|
||||
self.model_config = model.config
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
self.device = get_accelerator().get_current_device()
|
||||
if self.verbose:
|
||||
self.logger.info(f"the device is {self.device}")
|
||||
|
||||
model = model.to(self.dtype).eval()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if model_policy is None:
|
||||
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
|
||||
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
|
||||
model_policy = model_policy_map.get(model_policy_key)
|
||||
|
||||
if not isinstance(model_policy, Policy):
|
||||
try:
|
||||
model_policy = model_policy()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||
|
||||
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
||||
self.model = ModelWrapper(model).to(self.device)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if pretrained_path:
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
cpt_io = InferCheckpoint_io()
|
||||
if_has_index_file, model_index_file = has_index_file(pretrained_path)
|
||||
assert if_has_index_file, "the model path is invalid"
|
||||
cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
|
||||
assert self.use_cuda_graph, "please turn on the cuda graph"
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture begin")
|
||||
|
||||
t_capture_begin = time.perf_counter()
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
|
||||
|
||||
# Prepare dummy inputs. These will be reused for all batch sizes.
|
||||
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
||||
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
|
||||
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
|
||||
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
||||
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
|
||||
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
|
||||
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
|
||||
self.graph_block_tables[0, :] = np.arange(
|
||||
0, max_num_blocks
|
||||
)  # NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid in flash decoding, by making the first seqlen the max_seq_len
|
||||
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
||||
output_tensor = torch.zeros(
|
||||
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
|
||||
)
|
||||
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
|
||||
|
||||
max_num_seqs = self.inference_config.max_batch_size
|
||||
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
|
||||
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
|
||||
# NOTE this is a hack to ensure the CUDA graph can capture the fixed CUDA kernel grid in flash decoding, by making the first seqlen the max_seq_len
|
||||
sequence_lengths[0] = torch.tensor(
|
||||
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
|
||||
).cuda()
|
||||
|
||||
# NOTE: Capturing the largest batch size first may help reduce the
|
||||
# memory usage of CUDA graph.
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
if self.verbose:
|
||||
self.logger.info(f"batch size {batch_size} graph capturing")
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=block_tables[:batch_size],
|
||||
sequence_lengths=sequence_lengths[:batch_size],
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
batch_size=batch_size,
|
||||
is_prompts=False,
|
||||
use_cuda_graph=True,
|
||||
high_precision=False,
|
||||
kv_seq_len=sequence_lengths[:batch_size].max().item(),
|
||||
head_dim=head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
graph_runner = CUDAGraphRunner(self.model)
|
||||
graph_runner.capture(
|
||||
input_tokens_ids[:batch_size],
|
||||
output_tensor[:batch_size],
|
||||
input_meta_data,
|
||||
k_caches=k_cache,
|
||||
v_caches=v_cache,
|
||||
memory_pool=self.graph_memory_pool,
|
||||
)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[batch_size] = graph_runner
|
||||
|
||||
t_capture_end = time.perf_counter()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
"""Verify the input args"""
|
||||
if not isinstance(self.inference_config, InferenceConfig):
|
||||
raise TypeError("Invalid type of inference config provided.")
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
||||
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
|
||||
raise TypeError(
|
||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
||||
)
|
||||
if isinstance(self.model, ModelWrapper):
|
||||
model = self.model.module
|
||||
assert (
|
||||
model.__class__.__name__ in _supported_models.keys()
|
||||
), f"Model {self.model.__class__.__name__} is not supported."
|
||||
|
||||
def _shardformer(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_policy: Policy,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
stage_manager: PipelineStageManager = None,
|
||||
tp_group: ProcessGroupMesh = None,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Initialize ShardConfig and replace the model with shardformer.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Path or nn.Module of this model.
|
||||
model_policy (Policy): The policy to shardformer model which is determined by the model type.
|
||||
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
|
||||
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model optimized by Shardformer.
|
||||
"""
|
||||
|
||||
shardconfig = ShardConfig(
|
||||
tensor_parallel_process_group=tp_group,
|
||||
pipeline_stage_manager=stage_manager,
|
||||
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
|
||||
enable_fused_normalization=False,
|
||||
enable_all_optimization=False,
|
||||
enable_flash_attention=False,
|
||||
enable_jit_fused=False,
|
||||
enable_sequence_parallelism=False,
|
||||
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
|
||||
)
|
||||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model
|
||||
|
||||
def enable_spec_dec(
|
||||
self,
|
||||
drafter_model: nn.Module = None,
|
||||
n_spec_tokens: int = None,
|
||||
use_glide_drafter: bool = False,
|
||||
) -> None:
|
||||
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||
|
||||
Args:
|
||||
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
|
||||
If provided, the previous drafter and drafter model, if exist, will be overwritten.
|
||||
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
|
||||
If True, the drafter model will be replaced by a glide model.
|
||||
|
||||
```python
|
||||
...
|
||||
engine = InferenceEngine(model, tokenizer, inference_config)
|
||||
|
||||
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||
engine.generate(...) # Speculative Decoding
|
||||
|
||||
engine.disable_spec_dec()
|
||||
engine.generate(...) # Normal generation
|
||||
|
||||
engine.enable_spec_dec()
|
||||
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
|
||||
engine.clear_spec_dec()
|
||||
```
|
||||
"""
|
||||
|
||||
if drafter_model is None and self.drafter is None:
|
||||
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||
if n_spec_tokens is not None:
|
||||
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
|
||||
self.n_spec_tokens = n_spec_tokens
|
||||
if drafter_model is not None:
|
||||
assert isinstance(drafter_model, nn.Module)
|
||||
# overwrite the drafter, if exists
|
||||
self.clear_spec_dec()
|
||||
self.drafter_model = drafter_model
|
||||
self.drafter = Drafter(
|
||||
self.drafter_model,
|
||||
self.tokenizer,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# check if the provided drafter model is compatible with GLIDE structure
|
||||
# when `use_glide_drafter` is set to True
|
||||
if (
|
||||
use_glide_drafter
|
||||
and hasattr(drafter_model, "model")
|
||||
and hasattr(drafter_model.model, "layers")
|
||||
and hasattr(drafter_model.model.layers[0], "cross_attn")
|
||||
):
|
||||
self.use_glide = use_glide_drafter
|
||||
elif use_glide_drafter:
|
||||
self.logger.warning(
|
||||
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
|
||||
f"but the provided drafter model is not compatible with GLIDE structure."
|
||||
f"Falling back to use the default drafter model (non-GLIDE)."
|
||||
)
|
||||
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
|
||||
# using speculative decoding for subsequent generations
|
||||
self.use_spec_dec = True
|
||||
|
||||
def disable_spec_dec(self) -> None:
|
||||
"""Disable using speculative decoding for subsequent generations."""
|
||||
self.request_handler.unset_spec_dec_mode()
|
||||
# set back to the maximum number of tokens to speculate
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def clear_spec_dec(self) -> None:
|
||||
"""Clear relatable structures of speculative decoding, if exist."""
|
||||
if self.use_spec_dec:
|
||||
self.disable_spec_dec()
|
||||
if self.drafter_model or self.drafter:
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
torch.cuda.empty_cache()
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def steps_spec_dec(self) -> List[Sequence]:
|
||||
"""
|
||||
Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
|
||||
with many steps of speculating by a drafter model as well as verifying by a main model.
|
||||
|
||||
Returns:
|
||||
List[Sequence]: finished sequences generated by one step.
|
||||
"""
|
||||
batch = self.request_handler.schedule() # prefill batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
|
||||
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
|
||||
# 2. Prefill main model (Verifier) - fill past kv cache for main model
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
# append new inputs to the batch, temporarily
|
||||
batch.append_batch_tokens(next_tokens)
|
||||
self.request_handler.allocate_batch_spec_dec(batch, 1)
|
||||
already_allocated_kv_len = batch.seq_lengths[0].item()
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(1)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
while True:
|
||||
# HACK Retrieve the running batch
|
||||
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
|
||||
batch = self.request_handler.running_bb # running batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
# 3. Decoding - Drafter model speculates `n` tokens
|
||||
glide_input = None
|
||||
if self.use_glide:
|
||||
glide_input = GlideInput(
|
||||
batch.get_block_table_tensor(),
|
||||
self.k_cache[-1],  # use kv caches of the last layer
|
||||
self.v_cache[-1],
|
||||
batch.get_sequence_lengths(),
|
||||
n_spec_tokens=self.n_spec_tokens,
|
||||
)
|
||||
|
||||
drafter_out = self.drafter.speculate(
|
||||
input_token_ids,
|
||||
self.n_spec_tokens,
|
||||
drafter_past_key_values,
|
||||
glide_input=glide_input,
|
||||
)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
drafter_spec_length = drafter_out.speculated_length
|
||||
|
||||
for next_token_id_spec in next_token_ids_spec:
|
||||
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
|
||||
cur_length = batch.seq_lengths[0].item()
|
||||
if already_allocated_kv_len < cur_length:
|
||||
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
|
||||
already_allocated_kv_len = cur_length
|
||||
|
||||
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||
if drafter_spec_length < batch.num_tokens_to_verify:
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
|
||||
# 5. Compare and process the results
|
||||
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
|
||||
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
|
||||
|
||||
# revoke appended tokens for each Sequence in the current batch
|
||||
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
|
||||
|
||||
# append the last correct token generated by the main model
|
||||
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||
|
||||
# trim past key values of the drafter model
|
||||
drafter_past_key_values = Drafter.trim_kv_cache(
|
||||
drafter_past_key_values, drafter_spec_length - n_matches - 1
|
||||
)
|
||||
|
||||
# prepare inputs for the next round of speculation
|
||||
n = 1 if n_matches < drafter_spec_length else 2
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(n)
|
||||
|
||||
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
|
||||
finished_sequences = self.request_handler.update()
|
||||
if len(finished_sequences) > 0:
|
||||
break
|
||||
|
||||
# Reset back the number of speculated tokens of the batch,
|
||||
# this is used to handle the last round of speculation, in which case the number of speculated tokens
|
||||
# by the drafter is less than the number of speculated tokens set to the engine.
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
|
||||
|
||||
return finished_sequences
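Review note (not part of the diff): a tiny worked example of the accept/reject rule in step 5 above, showing how `n_matches` is derived when the verifier disagrees with the drafter partway through.

```python
import torch

# The drafter proposed 5 tokens; the verifier produced 6 outputs (5 verifications + 1 bonus token).
next_token_ids_spec = torch.tensor([11, 12, 13, 99, 15])       # drafter's speculated tokens
next_tokens         = torch.tensor([11, 12, 13, 14, 15, 16])   # main-model outputs
drafter_spec_length = next_token_ids_spec.numel()

diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

print(n_matches)  # 3 -> 11, 12, 13 are accepted; the two remaining drafted tokens are revoked,
                  # and next_tokens[3] (= 14) is appended as the last correct token.
```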
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
assert self._initialized, "Engine must be initialized"
|
||||
|
||||
def generate(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
|
||||
"""
|
||||
Executing the inference step.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
|
||||
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
|
||||
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
|
||||
"""
|
||||
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
prompts = [prompts] if isinstance(prompts, str) else prompts
|
||||
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
|
||||
|
||||
with torch.inference_mode():
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
self.add_request(
|
||||
request_ids=request_ids,
|
||||
prompts=prompts,
|
||||
prompts_token_ids=prompts_token_ids,
|
||||
**gen_config_dict,
|
||||
)
|
||||
|
||||
output_seqs_list = []
|
||||
total_tokens_list = []
|
||||
|
||||
# intuition: If user provide a generation config, we should replace the existing one.
|
||||
if generation_config is not None:
|
||||
self.generation_config = generation_config
|
||||
self.generation_config_dict = gen_config_dict
|
||||
|
||||
if self.use_spec_dec:
|
||||
assert self.drafter is not None, "Drafter Model is not initialized."
|
||||
while self.request_handler.check_unfinished_seqs():
|
||||
output_seqs_list += self.steps_spec_dec()
|
||||
else:
|
||||
while self.request_handler.check_unfinished_seqs():
|
||||
output_seqs_list += self.step()
|
||||
|
||||
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
|
||||
|
||||
for seq in output_seqs_list:
|
||||
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
|
||||
|
||||
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
|
||||
|
||||
if return_token_ids:
|
||||
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
|
||||
return output_str, output_tokens_list
|
||||
else:
|
||||
return output_str
|
||||
|
||||
@property
|
||||
def has_prompt_template(self) -> bool:
|
||||
""" """
|
||||
return self.inference_config.prompt_template is not None
|
||||
|
||||
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
|
||||
"""
|
||||
This method will format the input prompt according to the prompt template given to the InferenceConfig.
|
||||
"""
|
||||
assert (
|
||||
self.has_prompt_template
|
||||
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
|
||||
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
|
||||
elif isinstance(prompts, str):
|
||||
return self.inference_config.prompt_template.format(input_text=prompts)
|
||||
else:
|
||||
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
|
@@ -630,168 +98,36 @@ class InferenceEngine:
|
|||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
kwargs: for LLM, it could be max_length, max_new_tokens, etc
|
||||
for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers
|
||||
"""
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
def step(self):
|
||||
assert self.engine is not None, "Please init Engine first"
|
||||
return self.engine.step()
|
||||
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
|
||||
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
|
||||
"input_ids"
|
||||
]
|
||||
|
||||
# list of torch Tensor
|
||||
if isinstance(prompts_token_ids, list):
|
||||
if isinstance(prompts_token_ids[0], torch.Tensor):
|
||||
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
|
||||
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
|
||||
prompts_token_ids = prompts_token_ids.tolist()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
|
||||
)
|
||||
|
||||
assert (
|
||||
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
|
||||
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
|
||||
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
The design logic of __getattr__ / __setattr__:
1. Since InferenceEngine is a wrapper around DiffusionEngine/LLMEngine, we want to access members of DiffusionEngine/LLMEngine as if they were members of InferenceEngine itself.
2. During InferenceEngine.__init__ we still want to set attributes the ordinary way (self.xxx = xxx) rather than via self.__dict__["xxx"] = xxx.
So we set the attribute `_initialized`. After initialization, if a member cannot be found on InferenceEngine, we fall back to looking it up on self.engine (DiffusionEngine/LLMEngine).
"""
|
||||
if self.__dict__.get("_initialized", False):
|
||||
if name in self.__dict__:
|
||||
return self.__dict__[name]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
if prompts is None:
|
||||
prompt = None
|
||||
return getattr(self.engine, name)
|
||||
else:
|
||||
return self.__dict__[name]
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if self.__dict__.get("_initialized", False):
|
||||
if name in self.__dict__:
|
||||
self.__dict__[name] = value
|
||||
else:
|
||||
prompt = prompts[i]
|
||||
|
||||
max_length = kwargs.get("max_length", None)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", None)
|
||||
if max_length is None and max_new_tokens is None:
|
||||
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
|
||||
elif max_length is not None:
|
||||
max_new_tokens = max_length - len(prompts_token_ids[i])
|
||||
|
||||
if not self.inference_config.enable_streamingllm:
|
||||
assert (
|
||||
self.inference_config.max_output_len >= max_new_tokens
|
||||
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
|
||||
|
||||
sequence = Sequence(
|
||||
request_id,
|
||||
prompt,
|
||||
prompts_token_ids[i],
|
||||
block_size,
|
||||
None,
|
||||
self.tokenizer.eos_token_id,
|
||||
self.tokenizer.pad_token_id,
|
||||
max_output_len=max_new_tokens,
|
||||
ignore_eos=self.inference_config.ignore_eos,
|
||||
)
|
||||
self.request_handler.add_sequence(sequence)
|
||||
|
||||
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
|
||||
input_ids = batch.get_1D_inputs()
|
||||
sequence_lengths = batch.get_sequence_lengths()
|
||||
|
||||
if batch.is_prompts:
|
||||
n_tokens = sequence_lengths.sum().item()
|
||||
setattr(self.engine, name, value)
|
||||
else:
|
||||
n_tokens = batch.current_batch_size
|
||||
if batch.use_spec_dec:
|
||||
n_tokens = batch.num_tokens_to_verify + 1
|
||||
assert n_tokens == input_ids.size(0)
|
||||
n_tokens = n_tokens * batch.current_batch_size
|
||||
output_tensor = torch.zeros(
|
||||
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||
)
|
||||
|
||||
batch_token_ids = None
|
||||
if (
|
||||
self.generation_config.repetition_penalty != 1.0
|
||||
or self.generation_config.no_repeat_ngram_size > 0
|
||||
or self.generation_config.forced_eos_token_id is not None
|
||||
):
|
||||
batch_token_ids = batch.batch_token_ids
|
||||
|
||||
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
use_cuda_graph = False
|
||||
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
|
||||
use_cuda_graph = True
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=batch.get_block_table_tensor(),
|
||||
sequence_lengths=sequence_lengths,
|
||||
fd_inter_tensor=batch.fd_inter_tensor,
|
||||
batch_size=batch.current_batch_size,
|
||||
is_prompts=batch.is_prompts,
|
||||
use_cuda_kernel=self.inference_config.use_cuda_kernel,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
high_precision=self.high_precision,
|
||||
kv_seq_len=sequence_lengths.max().item(),
|
||||
head_dim=batch.head_dim,
|
||||
dtype=batch.dtype,
|
||||
use_spec_dec=batch.use_spec_dec,
|
||||
num_tokens_to_verify=batch.num_tokens_to_verify,
|
||||
batch_token_ids=batch_token_ids,
|
||||
)
|
||||
|
||||
return input_ids, output_tensor, input_meta_data
|
||||
|
||||
def step(self) -> List[str]:
|
||||
"""
|
||||
In each step, do the follows:
|
||||
1. Run RequestHandler.schedule() and get the batch used for inference.
|
||||
2. Get the input, inputinfo and output placeholder from the batchbucket
|
||||
3. Run model to generate the next token
|
||||
4. Update waiting list and running list in RequestHandler and get finished sequences.
|
||||
5. Decode and return finished sequences.
|
||||
|
||||
Returns:
|
||||
List[str]: Decoded finished sequences generated by one step.
|
||||
"""
|
||||
|
||||
batch = self.request_handler.schedule()
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
if self.inference_config.pad_input:
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if self.inference_config.enable_streamingllm:
|
||||
updated_block_ids = batch.streamingllm_update_batch(
|
||||
self.inference_config.start_token_size, self.inference_config.generated_token_size
|
||||
)
|
||||
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
|
||||
|
||||
next_tokens = search_tokens(
|
||||
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
|
||||
)
|
||||
self.request_handler.append_next_tokens(next_tokens)
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
return finished_sequences
|
||||
self.__dict__[name] = value
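Review note (not part of the diff): a self-contained sketch of the delegation pattern described in the docstring above. The class names here are hypothetical stand-ins for InferenceEngine and LLMEngine/DiffusionEngine.

```python
class _ToyEngine:
    """Stand-in for LLMEngine/DiffusionEngine."""
    def __init__(self):
        self.n_spec_tokens = 5


class ToyWrapper:
    """Mimics InferenceEngine's __getattr__/__setattr__ forwarding."""
    def __init__(self, inner):
        self.__dict__["_initialized"] = False  # bypass __setattr__ while initializing
        self.engine = inner                    # stored on the wrapper itself
        self._initialized = True

    def __getattr__(self, name):
        # Only called when normal lookup fails -> forward to the wrapped engine.
        return getattr(self.__dict__["engine"], name)

    def __setattr__(self, name, value):
        if self.__dict__.get("_initialized", False) and name not in self.__dict__:
            setattr(self.engine, name, value)  # unknown attributes go to the wrapped engine
        else:
            self.__dict__[name] = value


w = ToyWrapper(_ToyEngine())
print(w.n_spec_tokens)          # 5, read through to the inner engine
w.n_spec_tokens = 7             # written through to the inner engine
print(w.engine.n_spec_tokens)   # 7
```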
|
||||
|
|
|
@@ -0,0 +1,758 @@
|
|||
import time
|
||||
from itertools import count
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import distributed as dist
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
GenerationConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.batch_bucket import BatchBucket
|
||||
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.sampler import search_tokens
|
||||
from colossalai.inference.spec import Drafter, GlideInput
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.inference.utils import get_model_size, has_index_file
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from .base_engine import BaseEngine
|
||||
from .request_handler import RequestHandler
|
||||
|
||||
PP_AXIS, TP_AXIS = 0, 1
|
||||
|
||||
_supported_models = {
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"BaichuanForCausalLM": AutoModelForCausalLM,
|
||||
}
|
||||
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
||||
|
||||
|
||||
class LLMEngine(BaseEngine):
|
||||
"""
|
||||
LLMEngine which manages the LLM inference process.
|
||||
|
||||
Args:
|
||||
model_or_path (nn.Module or str): Path or nn.Module of this model.
|
||||
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
|
||||
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
|
||||
verbose (bool): Determine whether or not to log the generation process.
|
||||
model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_or_path: nn.Module | str,
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
|
||||
inference_config: InferenceConfig = None,
|
||||
verbose: bool = False,
|
||||
model_policy: Policy | type[Policy] = None,
|
||||
) -> None:
|
||||
self.inference_config = inference_config
|
||||
self.dtype = inference_config.dtype
|
||||
self.high_precision = inference_config.high_precision
|
||||
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger(__name__)
|
||||
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
|
||||
|
||||
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
|
||||
|
||||
self.generation_config = inference_config.to_generation_config(self.model_config)
|
||||
self.generation_config_dict = self.generation_config.to_dict()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.request_handler = RequestHandler(self.inference_config, self.model_config)
|
||||
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
|
||||
# DISCUSS maybe move this into batch info?
|
||||
|
||||
self.counter = count()
|
||||
|
||||
self.use_cuda_graph = self.inference_config.use_cuda_graph
|
||||
if self.use_cuda_graph:
|
||||
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||
self.graph_memory_pool = None # Set during graph capture.
|
||||
if verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture on")
|
||||
|
||||
self.capture_model(self.k_cache, self.v_cache)
|
||||
|
||||
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
|
||||
self.use_spec_dec = self.inference_config.use_spec_dec
|
||||
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
self.use_glide = False
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
|
||||
self._verify_args()
|
||||
|
||||
def init_model(
|
||||
self,
|
||||
model_or_path: Union[nn.Module, str],
|
||||
model_policy: Union[Policy, Type[Policy]] = None,
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
"""
|
||||
Shard model or/and Load weight
|
||||
|
||||
Args:
|
||||
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
|
||||
model_policy (Policy): the policy to replace the model.
|
||||
model_inference_config: the configuration for modeling initialization when inference.
|
||||
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
|
||||
"""
|
||||
pretrained_path = None
|
||||
if isinstance(model_or_path, str):
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
|
||||
arch = getattr(hf_config, "architectures")[0]
|
||||
if arch in _supported_models.keys():
|
||||
if arch == "BaichuanForCausalLM":
|
||||
self.logger.warning(
|
||||
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
|
||||
)
|
||||
ctx = LazyInitContext(default_device="cuda")
|
||||
with ctx:
|
||||
model = _supported_models[arch].from_pretrained(
|
||||
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
|
||||
)
|
||||
pretrained_path = pretrained_utils.get_pretrained_path(model)
|
||||
else:
|
||||
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
|
||||
raise ValueError(f"Model {arch} is not supported.")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
|
||||
)
|
||||
else:
|
||||
model = model_or_path
|
||||
|
||||
self.model_config = model.config
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
self.device = get_accelerator().get_current_device()
|
||||
if self.verbose:
|
||||
self.logger.info(f"the device is {self.device}")
|
||||
|
||||
model = model.to(self.dtype).eval()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if model_policy is None:
|
||||
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
|
||||
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
|
||||
model_policy = model_policy_map.get(model_policy_key)
|
||||
|
||||
if not isinstance(model_policy, Policy):
|
||||
try:
|
||||
model_policy = model_policy()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to instantiate model policy: {e}")
|
||||
|
||||
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
|
||||
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
|
||||
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
|
||||
self.model = self._shardformer(
|
||||
model,
|
||||
model_policy,
|
||||
model_shard_infer_config,
|
||||
None,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
|
||||
self.model = ModelWrapper(model).to(self.device)
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
|
||||
)
|
||||
|
||||
if pretrained_path:
|
||||
from colossalai.inference.core.plugin import InferCheckpoint_io
|
||||
|
||||
cpt_io = InferCheckpoint_io()
|
||||
if_has_index_file, model_index_file = has_index_file(pretrained_path)
|
||||
assert if_has_index_file, "the model path is invalid: no index file found under the pretrained path"
|
||||
cpt_io.load_model(self.model, model_index_file)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
peak_memory = init_gpu_memory - free_gpu_memory
|
||||
if self.verbose:
|
||||
self.logger.info(
|
||||
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
|
||||
assert self.use_cuda_graph, "CUDA graph capture requires use_cuda_graph to be enabled in the inference config"
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info("Colossal AI CUDA Graph Capture begin")
|
||||
|
||||
t_capture_begin = time.perf_counter()
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
|
||||
|
||||
# Prepare dummy inputs. These will be reused for all batch sizes.
|
||||
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
||||
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
|
||||
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
|
||||
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
||||
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
|
||||
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
|
||||
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
|
||||
self.graph_block_tables[0, :] = np.arange(
|
||||
0, max_num_blocks
|
||||
)  # NOTE: this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by treating the first seqlen as max_seq_len
|
||||
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
||||
output_tensor = torch.zeros(
|
||||
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
|
||||
)
|
||||
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
|
||||
|
||||
max_num_seqs = self.inference_config.max_batch_size
|
||||
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
|
||||
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
|
||||
# NOTE: this is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding, by treating the first seqlen as max_seq_len
|
||||
sequence_lengths[0] = torch.tensor(
|
||||
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
|
||||
).cuda()
|
||||
|
||||
# NOTE: Capturing the largest batch size first may help reduce the
|
||||
# memory usage of CUDA graph.
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
if self.verbose:
|
||||
self.logger.info(f"batch size {batch_size} graph capturing")
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=block_tables[:batch_size],
|
||||
sequence_lengths=sequence_lengths[:batch_size],
|
||||
fd_inter_tensor=fd_inter_tensor,
|
||||
batch_size=batch_size,
|
||||
is_prompts=False,
|
||||
use_cuda_graph=True,
|
||||
high_precision=False,
|
||||
kv_seq_len=sequence_lengths[:batch_size].max().item(),
|
||||
head_dim=head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
graph_runner = CUDAGraphRunner(self.model)
|
||||
graph_runner.capture(
|
||||
input_tokens_ids[:batch_size],
|
||||
output_tensor[:batch_size],
|
||||
input_meta_data,
|
||||
k_caches=k_cache,
|
||||
v_caches=v_cache,
|
||||
memory_pool=self.graph_memory_pool,
|
||||
)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[batch_size] = graph_runner
|
||||
|
||||
t_capture_end = time.perf_counter()
|
||||
|
||||
if self.verbose:
|
||||
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
"""Verify the input args"""
|
||||
if not isinstance(self.inference_config, InferenceConfig):
|
||||
raise TypeError("Invalid type of inference config provided.")
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
|
||||
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
|
||||
raise TypeError(
|
||||
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
|
||||
)
|
||||
if isinstance(self.model, ModelWrapper):
|
||||
model = self.model.module
|
||||
assert (
|
||||
model.__class__.__name__ in _supported_models.keys()
|
||||
), f"Model {self.model.__class__.__name__} is not supported."
|
||||
|
||||
def enable_spec_dec(
|
||||
self,
|
||||
drafter_model: nn.Module = None,
|
||||
n_spec_tokens: int = None,
|
||||
use_glide_drafter: bool = False,
|
||||
) -> None:
|
||||
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||
|
||||
Args:
|
||||
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
|
||||
If provided, the previous drafter and drafter model, if they exist, will be overwritten.
|
||||
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
|
||||
If True, the drafter model will be replaced by a glide model.
|
||||
|
||||
```python
|
||||
...
|
||||
engine = InferenceEngine(model, tokenizer, inference_config)
|
||||
|
||||
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||
engine.generate(...) # Speculative Decoding
|
||||
|
||||
engine.disable_spec_dec()
|
||||
engine.generate(...) # Normal generation
|
||||
|
||||
engine.enable_spec_dec()
|
||||
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
|
||||
engine.clear_spec_dec()
|
||||
```
|
||||
"""
|
||||
|
||||
if drafter_model is None and self.drafter is None:
|
||||
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
|
||||
if n_spec_tokens is not None:
|
||||
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
|
||||
self.n_spec_tokens = n_spec_tokens
|
||||
if drafter_model is not None:
|
||||
assert isinstance(drafter_model, nn.Module)
|
||||
# overwrite the drafter, if exists
|
||||
self.clear_spec_dec()
|
||||
self.drafter_model = drafter_model
|
||||
self.drafter = Drafter(
|
||||
self.drafter_model,
|
||||
self.tokenizer,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# check if the provided drafter model is compatible with GLIDE structure
|
||||
# when `use_glide_drafter` is set to True
|
||||
if (
|
||||
use_glide_drafter
|
||||
and hasattr(drafter_model, "model")
|
||||
and hasattr(drafter_model.model, "layers")
|
||||
and hasattr(drafter_model.model.layers[0], "cross_attn")
|
||||
):
|
||||
self.use_glide = use_glide_drafter
|
||||
elif use_glide_drafter:
|
||||
self.logger.warning(
|
||||
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
|
||||
f"but the provided drafter model is not compatible with GLIDE structure."
|
||||
f"Falling back to use the default drafter model (non-GLIDE)."
|
||||
)
|
||||
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
|
||||
# using speculative decoding for subsequent generations
|
||||
self.use_spec_dec = True
|
||||
|
||||
def disable_spec_dec(self) -> None:
|
||||
"""Disable using speculative decoding for subsequent generations."""
|
||||
self.request_handler.unset_spec_dec_mode()
|
||||
# set back to the maximum number of tokens to speculate
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def clear_spec_dec(self) -> None:
|
||||
"""Clear relatable structures of speculative decoding, if exist."""
|
||||
if self.use_spec_dec:
|
||||
self.disable_spec_dec()
|
||||
if self.drafter_model or self.drafter:
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
torch.cuda.empty_cache()
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def steps_spec_dec(self) -> List[Sequence]:
|
||||
"""
|
||||
Run speculative decoding steps. This retrieves a single batch and launches inference
with multiple rounds of speculation by a drafter model and verification by the main model.
|
||||
|
||||
Returns:
|
||||
List[Sequence]: finished sequences generated by one step.
|
||||
"""
|
||||
batch = self.request_handler.schedule() # prefill batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
|
||||
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
|
||||
# 2. Prefill main model (Verifier) - fill past kv cache for main model
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
# append new inputs to the batch, temporarily
|
||||
batch.append_batch_tokens(next_tokens)
|
||||
self.request_handler.allocate_batch_spec_dec(batch, 1)
|
||||
already_allocated_kv_len = batch.seq_lengths[0].item()
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(1)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
while True:
|
||||
# HACK Retrieve the running batch
|
||||
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
|
||||
batch = self.request_handler.running_bb # running batch
|
||||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
# 3. Decoding - Drafter model speculates `n` tokens
|
||||
glide_input = None
|
||||
if self.use_glide:
|
||||
glide_input = GlideInput(
|
||||
batch.get_block_table_tensor(),
|
||||
self.k_cache[-1],  # use kv caches of the last layer
|
||||
self.v_cache[-1],
|
||||
batch.get_sequence_lengths(),
|
||||
n_spec_tokens=self.n_spec_tokens,
|
||||
)
|
||||
|
||||
drafter_out = self.drafter.speculate(
|
||||
input_token_ids,
|
||||
self.n_spec_tokens,
|
||||
drafter_past_key_values,
|
||||
glide_input=glide_input,
|
||||
)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
drafter_spec_length = drafter_out.speculated_length
|
||||
|
||||
for next_token_id_spec in next_token_ids_spec:
|
||||
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
|
||||
cur_length = batch.seq_lengths[0].item()
|
||||
if already_allocated_kv_len < cur_length:
|
||||
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
|
||||
already_allocated_kv_len = cur_length
|
||||
|
||||
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||
if drafter_spec_length < batch.num_tokens_to_verify:
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
|
||||
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
|
||||
|
||||
# 5. Compare and process the results
|
||||
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
|
||||
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
|
||||
|
||||
# revoke appended tokens for each Sequence in the current batch
|
||||
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
|
||||
|
||||
# append the last correct token generated by the main model
|
||||
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||
|
||||
# trim past key values of the drafter model
|
||||
drafter_past_key_values = Drafter.trim_kv_cache(
|
||||
drafter_past_key_values, drafter_spec_length - n_matches - 1
|
||||
)
|
||||
|
||||
# prepare inputs for the next round of speculation
|
||||
n = 1 if n_matches < drafter_spec_length else 2
|
||||
input_token_ids = batch.get_1D_inputs_spec_dec(n)
|
||||
|
||||
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
|
||||
finished_sequences = self.request_handler.update()
|
||||
if len(finished_sequences) > 0:
|
||||
break
|
||||
|
||||
# Reset back the number of speculated tokens of the batch,
|
||||
# this is used to handle the last round of speculation, in which case the number of speculated tokens
|
||||
# by the drafter is less than the number of speculated tokens set to the engine.
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
|
||||
|
||||
return finished_sequences
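To make step 5 above concrete, here is a toy example of the accept/verify rule with made-up token ids (the real tensors come from the drafter and the main model).

```python
import torch

# Drafter proposed 5 tokens; the verifier returns 6 tokens (one extra, the
# token following the last drafted position).
next_token_ids_spec = torch.tensor([11, 12, 13, 14, 15])
next_tokens = torch.tensor([11, 12, 99, 100, 101, 102])

diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = len(next_token_ids_spec) if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

# The first 2 drafted tokens are accepted, the remaining 3 are revoked, and
# next_tokens[n_matches] (= 99) is appended as the main model's correction.
assert n_matches == 2 and next_tokens[n_matches].item() == 99
```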
|
||||
|
||||
def generate(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
|
||||
"""
|
||||
Execute the inference step.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
|
||||
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
|
||||
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
|
||||
"""
|
||||
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
prompts = [prompts] if isinstance(prompts, str) else prompts
|
||||
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
|
||||
|
||||
with torch.inference_mode():
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
self.add_request(
|
||||
request_ids=request_ids,
|
||||
prompts=prompts,
|
||||
prompts_token_ids=prompts_token_ids,
|
||||
**gen_config_dict,
|
||||
)
|
||||
|
||||
output_seqs_list = []
|
||||
total_tokens_list = []
|
||||
|
||||
# If the user provides a generation config, replace the existing one.
|
||||
if generation_config is not None:
|
||||
self.generation_config = generation_config
|
||||
self.generation_config_dict = gen_config_dict
|
||||
|
||||
if self.use_spec_dec:
|
||||
assert self.drafter is not None, "Drafter Model is not initialized."
|
||||
while self.request_handler.check_unfinished_reqs():
|
||||
output_seqs_list += self.steps_spec_dec()
|
||||
else:
|
||||
while self.request_handler.check_unfinished_reqs():
|
||||
output_seqs_list += self.step()
|
||||
|
||||
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
|
||||
|
||||
for seq in output_seqs_list:
|
||||
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
|
||||
|
||||
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
|
||||
|
||||
if return_token_ids:
|
||||
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
|
||||
return output_str, output_tokens_list
|
||||
else:
|
||||
return output_str
|
||||
|
||||
@property
|
||||
def has_prompt_template(self) -> bool:
|
||||
""" """
|
||||
return self.inference_config.prompt_template is not None
|
||||
|
||||
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
|
||||
"""
|
||||
This method will format the input prompt according to the prompt template given to the InferenceConfig.
|
||||
"""
|
||||
assert (
|
||||
self.has_prompt_template
|
||||
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
|
||||
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
|
||||
elif isinstance(prompts, str):
|
||||
return self.inference_config.prompt_template.format(input_text=prompts)
|
||||
else:
|
||||
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Add requests.
|
||||
|
||||
Args:
|
||||
request_ids (List[int], optional): The request ID. Defaults to None.
|
||||
prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
"""
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts_token_ids is None:
|
||||
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
|
||||
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
|
||||
"input_ids"
|
||||
]
|
||||
|
||||
# list of torch Tensor
|
||||
if isinstance(prompts_token_ids, list):
|
||||
if isinstance(prompts_token_ids[0], torch.Tensor):
|
||||
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
|
||||
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
|
||||
prompts_token_ids = prompts_token_ids.tolist()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
|
||||
)
|
||||
|
||||
assert (
|
||||
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
|
||||
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
|
||||
|
||||
prompts_num = len(prompts_token_ids)
|
||||
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
assert len(request_ids) == prompts_num
|
||||
request_id = request_ids[i]
|
||||
else:
|
||||
request_id = next(self.counter)
|
||||
if prompts is None:
|
||||
prompt = None
|
||||
else:
|
||||
prompt = prompts[i]
|
||||
|
||||
max_length = kwargs.get("max_length", None)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", None)
|
||||
if max_length is None and max_new_tokens is None:
|
||||
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
|
||||
elif max_length is not None:
|
||||
max_new_tokens = max_length - len(prompts_token_ids[i])
|
||||
|
||||
if not self.inference_config.enable_streamingllm:
|
||||
assert (
|
||||
self.inference_config.max_output_len >= max_new_tokens
|
||||
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
|
||||
|
||||
sequence = Sequence(
|
||||
request_id,
|
||||
prompt,
|
||||
prompts_token_ids[i],
|
||||
block_size,
|
||||
None,
|
||||
self.tokenizer.eos_token_id,
|
||||
self.tokenizer.pad_token_id,
|
||||
max_output_len=max_new_tokens,
|
||||
ignore_eos=self.inference_config.ignore_eos,
|
||||
)
|
||||
self.request_handler.add_sequence(sequence)
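The budget resolution above can be summarized with toy numbers (illustrative only):

```python
# Caller passes max_length: the new-token budget is what remains after the prompt.
prompt_len, max_length = 12, 50
max_new_tokens = max_length - prompt_len  # 38

# Caller passes neither max_length nor max_new_tokens: fall back to
# generation_config.max_new_tokens, or inference_config.max_output_len if unset.
```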
|
||||
|
||||
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
|
||||
input_ids = batch.get_1D_inputs()
|
||||
sequence_lengths = batch.get_sequence_lengths()
|
||||
|
||||
if batch.is_prompts:
|
||||
n_tokens = sequence_lengths.sum().item()
|
||||
else:
|
||||
n_tokens = batch.current_batch_size
|
||||
if batch.use_spec_dec:
|
||||
n_tokens = batch.num_tokens_to_verify + 1
|
||||
assert n_tokens == input_ids.size(0)
|
||||
n_tokens = n_tokens * batch.current_batch_size
|
||||
output_tensor = torch.zeros(
|
||||
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
|
||||
)
|
||||
|
||||
batch_token_ids = None
|
||||
if (
|
||||
self.generation_config.repetition_penalty != 1.0
|
||||
or self.generation_config.no_repeat_ngram_size > 0
|
||||
or self.generation_config.forced_eos_token_id is not None
|
||||
):
|
||||
batch_token_ids = batch.batch_token_ids
|
||||
|
||||
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
|
||||
use_cuda_graph = False
|
||||
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
|
||||
use_cuda_graph = True
|
||||
|
||||
input_meta_data = InputMetaData(
|
||||
block_tables=batch.get_block_table_tensor(),
|
||||
sequence_lengths=sequence_lengths,
|
||||
fd_inter_tensor=batch.fd_inter_tensor,
|
||||
batch_size=batch.current_batch_size,
|
||||
is_prompts=batch.is_prompts,
|
||||
use_cuda_kernel=self.inference_config.use_cuda_kernel,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
high_precision=self.high_precision,
|
||||
kv_seq_len=sequence_lengths.max().item(),
|
||||
head_dim=batch.head_dim,
|
||||
dtype=batch.dtype,
|
||||
use_spec_dec=batch.use_spec_dec,
|
||||
num_tokens_to_verify=batch.num_tokens_to_verify,
|
||||
batch_token_ids=batch_token_ids,
|
||||
)
|
||||
|
||||
return input_ids, output_tensor, input_meta_data
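A quick sketch of the output-placeholder sizing above, with assumed head counts (the real values come from the `BatchBucket`): during prefill one row is reserved per prompt token, during decoding one row per sequence (or `num_tokens_to_verify + 1` rows per sequence under speculative decoding).

```python
import torch

num_heads, head_dim = 32, 128                # assumed model shape
sequence_lengths = torch.tensor([5, 9, 3])   # three running sequences

prefill_rows = int(sequence_lengths.sum())   # 17 rows, one per prompt token
decode_rows = sequence_lengths.numel()       # 3 rows, one per sequence

prefill_out = torch.zeros((prefill_rows, num_heads * head_dim))
decode_out = torch.zeros((decode_rows, num_heads * head_dim))
```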
|
||||
|
||||
def step(self) -> List[str]:
|
||||
"""
|
||||
In each step, do the follows:
|
||||
1. Run RequestHandler.schedule() and get the batch used for inference.
|
||||
2. Get the input, input metadata, and output placeholder from the BatchBucket.
|
||||
3. Run model to generate the next token
|
||||
4. Update waiting list and running list in RequestHandler and get finished sequences.
|
||||
5. Decode and return finished sequences.
|
||||
|
||||
Returns:
|
||||
List[str]: Decoded finished sequences generated by one step.
|
||||
"""
|
||||
|
||||
batch = self.request_handler.schedule()
|
||||
|
||||
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
|
||||
|
||||
if input_meta_data.use_cuda_graph:
|
||||
model_executable = self.graph_runners[input_meta_data.batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
|
||||
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
|
||||
if self.inference_config.pad_input:
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if self.inference_config.enable_streamingllm:
|
||||
updated_block_ids = batch.streamingllm_update_batch(
|
||||
self.inference_config.start_token_size, self.inference_config.generated_token_size
|
||||
)
|
||||
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
|
||||
|
||||
next_tokens = search_tokens(
|
||||
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
|
||||
)
|
||||
self.request_handler.append_next_tokens(next_tokens)
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
return finished_sequences
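For reference, a minimal end-to-end sketch of the loop implemented by `generate()`/`step()`; the checkpoint name and config values are placeholders, and the constructor arguments follow the signatures shown in this diff.

```python
import colossalai
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch()

model_path = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# config fields assumed from the attributes referenced above
inference_config = InferenceConfig(max_batch_size=4, max_input_len=256, max_output_len=128)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# generate() enqueues prompts via add_request(), then calls step()
# (schedule -> prepare_input -> forward -> sample -> update) until all
# requests finish, and finally batch-decodes the token ids.
outputs = engine.generate(prompts=["Introduce some landmarks in Beijing"])
print(outputs[0])
```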
|
|
@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
|
|||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
|
||||
from colossalai.inference.struct import RequestStatus, Sequence
|
||||
from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
@ -98,7 +98,46 @@ class RunningList:
|
|||
self._decoding[seq_id] = self._prefill.pop(seq_id)
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
class NaiveRequestHandler:
|
||||
def __init__(self) -> None:
|
||||
self.running_list: List[DiffusionSequence] = []
|
||||
self.waiting_list: List[DiffusionSequence] = []
|
||||
|
||||
def _has_waiting(self) -> bool:
|
||||
return any(lst for lst in self.waiting_list)
|
||||
|
||||
def _has_running(self) -> bool:
|
||||
return any(lst for lst in self.running_list)
|
||||
|
||||
def check_unfinished_reqs(self):
|
||||
return self._has_waiting() or self._has_running()
|
||||
|
||||
def add_sequence(self, seq: DiffusionSequence):
|
||||
"""
|
||||
Add the request to waiting list.
|
||||
"""
|
||||
assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
|
||||
self.waiting_list.append(seq)
|
||||
|
||||
def _find_sequence(self, request_id: int) -> DiffusionSequence:
|
||||
"""
|
||||
Find the request by request_id.
|
||||
"""
|
||||
for seq in self.waiting_list + self.running_list:
|
||||
if seq.request_id == request_id:
|
||||
return seq
|
||||
return None
|
||||
|
||||
def schedule(self):
|
||||
ret = None
|
||||
if self._has_waiting():
|
||||
ret = self.waiting_list[0]
|
||||
self.waiting_list = self.waiting_list[1:]
|
||||
return ret
|
||||
|
||||
|
||||
class RequestHandler(NaiveRequestHandler):
|
||||
"""
|
||||
RequestHandler is the core for handling existing requests and updating current batch.
|
||||
During generation process, we call schedule function each iteration to update current batch.
|
||||
|
@ -176,12 +215,12 @@ class RequestHandler:
|
|||
generated_token_size=inference_config.generated_token_size,
|
||||
)
|
||||
|
||||
def _has_running(self) -> bool:
|
||||
return not self.running_bb.is_empty()
|
||||
|
||||
def _init_cache(self, model_config):
|
||||
self.cache_manager = KVCacheManager(self.inference_config, model_config)
|
||||
|
||||
def _has_waiting(self) -> bool:
|
||||
return any(lst for lst in self.waiting_list)
|
||||
|
||||
def get_kvcache(self):
|
||||
return self.cache_manager.get_kv_cache()
|
||||
|
||||
|
@ -318,7 +357,7 @@ class RequestHandler:
|
|||
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
|
||||
seq.mark_finished()
|
||||
|
||||
def check_unfinished_seqs(self) -> bool:
|
||||
def check_unfinished_reqs(self) -> bool:
|
||||
return self._has_waiting() or not self.running_list.is_empty()
|
||||
|
||||
def total_requests_in_batch_bucket(self) -> int:
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
import inspect
|
||||
import types
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DiffusionPipe(nn.Module):
|
||||
"""
|
||||
This class converts a `DiffusionPipeline` into an `nn.Module` and preserves most of the original attributes, functions, and properties.
|
||||
"""
|
||||
|
||||
def __init__(self, source_obj) -> None:
|
||||
super(DiffusionPipe, self).__init__()
|
||||
|
||||
for k, v in source_obj.__dict__.items():
|
||||
if isinstance(v, nn.Module):
|
||||
self.add_module(k, v)
|
||||
else:
|
||||
setattr(self, k, v)
|
||||
|
||||
skip_list = ["_execution_device", "to", "device"] # this
|
||||
|
||||
for name, member in inspect.getmembers(source_obj.__class__):
|
||||
if name in skip_list:
|
||||
continue
|
||||
if not name.startswith("__") and not name.endswith("__"):
|
||||
if isinstance(member, property):
|
||||
setattr(self.__class__, name, member)
|
||||
elif inspect.isfunction(member) or inspect.ismethod(member):
|
||||
bound_method = types.MethodType(member, self)
|
||||
setattr(self, name, bound_method)
|
||||
elif not callable(member) and not isinstance(member, property):
|
||||
setattr(self, name, member)
|
||||
elif name == "__call__":
|
||||
bound_method = types.MethodType(member, self)
|
||||
setattr(self, "_forward", bound_method)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
|
||||
Accelerate's module hooks.
|
||||
"""
|
||||
# return self.device
|
||||
return torch.device("cuda")
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self._forward(*args, **kwargs)
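A minimal sketch of how this wrapper is intended to be used, assuming a diffusers SD3 checkpoint; inside the engine the wrapper's `forward` is further replaced by the policy's `sd3_forward`/`pixart_alpha_forward`, so calling the raw wrapper as below is for illustration only.

```python
import torch
from diffusers import StableDiffusion3Pipeline

from colossalai.inference.modeling.models.diffusion import DiffusionPipe

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe = DiffusionPipe(pipeline).to("cuda")

# The original pipeline __call__ is re-bound as `_forward`, so calling the
# module runs the usual denoising loop; the result is assumed to be the
# standard diffusers pipeline output with an `images` field.
result = pipe(prompt="A cat holding a sign that says hello world")
result.images[0].save("cat.jpg")
```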
|
|
@ -0,0 +1,220 @@
|
|||
# Code adapted from:
|
||||
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
|
||||
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
|
||||
ASPECT_RATIO_256_BIN,
|
||||
ASPECT_RATIO_512_BIN,
|
||||
ASPECT_RATIO_1024_BIN,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .diffusion import DiffusionPipe
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def pixart_alpha_forward(
|
||||
self: DiffusionPipe,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: str = "",
|
||||
num_inference_steps: int = 20,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
guidance_scale: float = 4.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
use_resolution_binning: bool = True,
|
||||
max_sequence_length: int = 120,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
if use_resolution_binning:
|
||||
if self.transformer.config.sample_size == 128:
|
||||
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
|
||||
elif self.transformer.config.sample_size == 64:
|
||||
aspect_ratio_bin = ASPECT_RATIO_512_BIN
|
||||
elif self.transformer.config.sample_size == 32:
|
||||
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
||||
else:
|
||||
raise ValueError("Invalid sample size")
|
||||
orig_height, orig_width = height, width
|
||||
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_steps,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Prepare micro-conditions.
|
||||
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
||||
if self.transformer.config.sample_size == 128:
|
||||
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
||||
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
||||
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
|
||||
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
resolution = torch.cat([resolution, resolution], dim=0)
|
||||
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
|
||||
|
||||
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
current_timestep = t
|
||||
if not torch.is_tensor(current_timestep):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = latent_model_input.device.type == "mps"
|
||||
if isinstance(current_timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
||||
elif len(current_timestep.shape) == 0:
|
||||
current_timestep = current_timestep[None].to(latent_model_input.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
timestep=current_timestep,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# learned sigma
|
||||
if self.transformer.config.out_channels // 2 == latent_channels:
|
||||
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
||||
else:
|
||||
noise_pred = noise_pred
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
if num_inference_steps == 1:
|
||||
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
output_type = "pil" # TODO(@lry89757) temporarily image, please support more return output
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
if use_resolution_binning:
|
||||
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
# self.maybe_free_model_hooks()
|
||||
|
||||
return image
|
|
@ -0,0 +1,178 @@
|
|||
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
|
||||
|
||||
from .diffusion import DiffusionPipe
|
||||
|
||||
|
||||
# TODO(@lry89757) temporarily returns PIL images only; support more output types later
|
||||
@torch.no_grad()
|
||||
def sd3_forward(
|
||||
self: DiffusionPipe,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_3=prompt_3,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
clip_skip=self.clip_skip,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
return image
|
|
@ -1,16 +1,22 @@
|
|||
from .glide_llama import GlideLlamaModelPolicy
|
||||
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
|
||||
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
|
||||
from .pixart_alpha import PixArtAlphaInferPolicy
|
||||
from .stablediffusion3 import StableDiffusion3InferPolicy
|
||||
|
||||
model_policy_map = {
|
||||
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
|
||||
"nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
|
||||
"glide_llama": GlideLlamaModelPolicy,
|
||||
"StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
|
||||
"PixArtAlphaPipeline": PixArtAlphaInferPolicy,
|
||||
}
|
||||
|
||||
__all__ = [
|
||||
"NoPaddingLlamaModelInferPolicy",
|
||||
"NoPaddingBaichuanModelInferPolicy",
|
||||
"GlideLlamaModelPolicy",
|
||||
"StableDiffusion3InferPolicy",
|
||||
"PixArtAlphaInferPolicy",
|
||||
"model_polic_map",
|
||||
]
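To show how the registry above is keyed: LLM policies use a "<padding-mode>_<model_type>" key (built in `init_model`), while the diffusion policies are keyed by the diffusers pipeline class name. A lookup sketch, assuming this package path:

```python
from colossalai.inference.modeling.policy import model_policy_map

llm_policy_cls = model_policy_map["nopadding_llama"]            # NoPaddingLlamaModelInferPolicy
sd3_policy_cls = model_policy_map["StableDiffusion3Pipeline"]   # StableDiffusion3InferPolicy

policy = sd3_policy_cls()  # policies are instantiated before being passed to ShardFormer
```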
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
from torch import nn
|
||||
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
|
||||
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = {}
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
|
||||
)
|
||||
return policy
|
||||
|
||||
def preprocess(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def to_rpc_param(self) -> str:
|
||||
return __class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def from_rpc_param() -> "PixArtAlphaInferPolicy":
|
||||
return PixArtAlphaInferPolicy()
|
|
@ -0,0 +1,34 @@
|
|||
from torch import nn
|
||||
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
|
||||
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = {}
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
|
||||
)
|
||||
return policy
|
||||
|
||||
def preprocess(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def to_rpc_param(self) -> str:
|
||||
return __class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def from_rpc_param() -> "StableDiffusion3InferPolicy":
|
||||
return StableDiffusion3InferPolicy()
|
|
@ -2,6 +2,7 @@ import enum
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
from colossalai.inference.config import DiffusionGenerationConfig
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
@ -46,6 +47,17 @@ class RequestStatus(enum.Enum):
|
|||
return status == RequestStatus.WAITING
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionSequence:
|
||||
"""
|
||||
Parameters of a diffusion generation request.
|
||||
"""
|
||||
|
||||
request_id: int
|
||||
prompt: str
|
||||
generation_config: DiffusionGenerationConfig
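A diffusion request is just this lightweight record; for illustration (field values are examples):

```python
from colossalai.inference.config import DiffusionGenerationConfig
from colossalai.inference.struct import DiffusionSequence

seq = DiffusionSequence(
    request_id=0,
    prompt="A cat holding a sign that says hello world",
    generation_config=DiffusionGenerationConfig(num_inference_steps=28, guidance_scale=7.0),
)
```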
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sequence:
|
||||
"""Store information of input sequence.
|
||||
|
|
|
@ -5,10 +5,12 @@ Utils for model inference
|
|||
import math
|
||||
import os
|
||||
import re
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from torch import nn
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
|
|||
except ImportError:
|
||||
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
|
||||
return False
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
DIFFUSION_MODEL = "Diffusion Model"
|
||||
LLM = "Large Language Model (LLM)"
|
||||
UNKNOWN = "Unknown Model Type"
|
||||
|
||||
|
||||
def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
|
||||
if isinstance(model_or_path, DiffusionPipeline):
|
||||
return ModelType.DIFFUSION_MODEL
|
||||
elif isinstance(model_or_path, nn.Module):
|
||||
return ModelType.LLM
|
||||
elif isinstance(model_or_path, str):
|
||||
try:
|
||||
from transformers import AutoConfig
|
||||
|
||||
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
|
||||
return ModelType.LLM
|
||||
except Exception:
# the model type is not `ModelType.LLM`; fall through and probe for a diffusion pipeline
|
||||
|
||||
try:
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
DiffusionPipeline.load_config(model_or_path)
|
||||
return ModelType.DIFFUSION_MODEL
|
||||
except Exception:
# the model type is not `ModelType.DIFFUSION_MODEL` either
return ModelType.UNKNOWN
|
||||
else:
|
||||
return ModelType.UNKNOWN
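A small sketch of the dispatch above; the module path is assumed from the diff's file layout.

```python
import torch.nn as nn

from colossalai.inference.utils import ModelType, get_model_type  # assumed module path

# nn.Module instances count as LLMs; diffusers pipelines as diffusion models;
# strings are probed via AutoConfig first, then DiffusionPipeline.load_config.
assert get_model_type(nn.Linear(2, 2)) == ModelType.LLM
assert get_model_type(12345) == ModelType.UNKNOWN  # neither module, pipeline, nor path
```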
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
import argparse
|
||||
|
||||
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
|
||||
from torch import bfloat16, float16, float32
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
|
||||
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
|
||||
|
||||
# Select the pipeline/policy pair below: index 0 -> Stable Diffusion 3, index 1 -> PixArt-Alpha
|
||||
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
|
||||
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
|
||||
|
||||
TORCH_DTYPE_MAP = {
|
||||
"fp16": float16,
|
||||
"fp32": float32,
|
||||
"bf16": bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def infer(args):
|
||||
# ==============================
|
||||
# Launch colossalai, setup distributed environment
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Load model and tokenizer
|
||||
# ==============================
|
||||
model_path_or_name = args.model
|
||||
model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
|
||||
|
||||
# ==============================
|
||||
# Initialize InferenceEngine
|
||||
# ==============================
|
||||
coordinator.print_on_master(f"Initializing Inference Engine...")
|
||||
inference_config = InferenceConfig(
|
||||
dtype=args.dtype,
|
||||
max_batch_size=args.max_batch_size,
|
||||
tp_size=args.tp_size,
|
||||
use_cuda_kernel=args.use_cuda_kernel,
|
||||
)
|
||||
engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
|
||||
|
||||
# ==============================
|
||||
# Generation
|
||||
# ==============================
|
||||
coordinator.print_on_master(f"Generating...")
|
||||
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
|
||||
out.save("cat.jpg")
|
||||
coordinator.print_on_master(out)
|
||||
|
||||
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
|
||||
parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
|
||||
parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
|
||||
parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
|
||||
parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
|
||||
parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
|
||||
args = parser.parse_args()
|
||||
|
||||
infer(args)
|
|
@ -23,3 +23,4 @@ rpyc==6.0.0
|
|||
fastapi
|
||||
uvicorn==0.29.0
|
||||
galore_torch
|
||||
diffusers==0.29.0
|
||||
|
|