Browse Source

[Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support
pull/5894/head
Runyu Lu 5 months ago committed by GitHub
parent
commit
cba20525a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 48
      colossalai/inference/config.py
  2. 90
      colossalai/inference/core/base_engine.py
  3. 200
      colossalai/inference/core/diffusion_engine.py
  4. 800
      colossalai/inference/core/engine.py
  5. 758
      colossalai/inference/core/llm_engine.py
  6. 51
      colossalai/inference/core/request_handler.py
  7. 54
      colossalai/inference/modeling/models/diffusion.py
  8. 220
      colossalai/inference/modeling/models/pixart_alpha.py
  9. 178
      colossalai/inference/modeling/models/stablediffusion3.py
  10. 6
      colossalai/inference/modeling/policy/__init__.py
  11. 34
      colossalai/inference/modeling/policy/pixart_alpha.py
  12. 34
      colossalai/inference/modeling/policy/stablediffusion3.py
  13. 12
      colossalai/inference/struct.py
  14. 39
      colossalai/inference/utils.py
  15. 75
      examples/inference/stable_diffusion/sd3_generation.py
  16. 1
      requirements/requirements.txt

48
colossalai/inference/config.py

@ -5,7 +5,7 @@ Our config contains various options for inference optimization, it is a unified
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers.generation import GenerationConfig
@ -396,3 +396,49 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
@dataclass
class DiffusionGenerationConfig:
"""
Param for diffusion model forward
"""
prompt_2: Optional[Union[str, List[str]]] = None
prompt_3: Optional[Union[str, List[str]]] = None
height: Optional[int] = None
width: Optional[int] = None
num_inference_steps: int = None
timesteps: List[int] = None
guidance_scale: float = None
negative_prompt: Optional[Union[str, List[str]]] = (
None # NOTE(@lry89757) in pixart default to "", in sd3 default to None
)
negative_prompt_2: Optional[Union[str, List[str]]] = None
negative_prompt_3: Optional[Union[str, List[str]]] = None
num_images_per_prompt: Optional[int] = None
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None
latents: Optional[torch.FloatTensor] = None
prompt_embeds: Optional[torch.FloatTensor] = None
negative_prompt_embeds: Optional[torch.FloatTensor] = None
pooled_prompt_embeds: Optional[torch.FloatTensor] = None
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None
output_type: Optional[str] = None # "pil"
return_dict: bool = None
joint_attention_kwargs: Optional[Dict[str, Any]] = None
clip_skip: Optional[int] = None
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None
callback_on_step_end_tensor_inputs: List[str] = None
def to_dict(self) -> Dict[str, Any]:
# NOTE(@lry89757) Only return the dict that not the default value None
result = {}
for field in fields(self):
value = getattr(self, field.name)
if value is not None:
result[field.name] = value
return result
@classmethod
def from_kwargs(cls, **kwargs) -> "DiffusionGenerationConfig":
return cls(**kwargs)

90
colossalai/inference/core/base_engine.py

@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
class BaseEngine(ABC):
@abstractmethod
def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
pass
@abstractmethod
def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
"""
Init Model for Engine
"""
@abstractmethod
def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
"""
Generate ouptput for coming requests
"""
@abstractmethod
def add_request(self, prompts, request_ids=None, **kwargs):
"""
Add new request to Engine
"""
@abstractmethod
def step(self):
"""
Perform one new step forward
"""
@abstractmethod
def _verify_args(self):
"""
Verify the parameters and members of class
"""
@torch.inference_mode()
def capture_model(self):
"""
Use cuda graph to capture model
"""
return NotImplementedError("This method should be implemented by subclasses")
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
**kwargs,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): Path or nn.Module of this model.
model_policy (Policy): The policy to shardformer model which is determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: The model optimized by Shardformer.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model

200
colossalai/inference/core/diffusion_engine.py

@ -0,0 +1,200 @@
from itertools import count
from typing import List, Tuple, Type, Union
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from torch import distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .base_engine import BaseEngine
from .request_handler import NaiveRequestHandler
PP_AXIS, TP_AXIS = 0, 1
class DiffusionEngine(BaseEngine):
def __init__(
self,
model_or_path: DiffusionPipeline | str,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Policy | type[Policy] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.model_type = get_model_type(model_or_path=model_or_path)
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.request_handler = NaiveRequestHandler()
self.counter = count()
self._verify_args()
def _verify_args(self) -> None:
assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"
def init_model(
self,
model_or_path: Union[str, nn.Module, DiffusionPipeline],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
if isinstance(model_or_path, str):
model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
policy_map_key = model.__class__.__name__
model = DiffusionPipe(model)
elif isinstance(model_or_path, DiffusionPipeline):
policy_map_key = model_or_path.__class__.__name__
model = DiffusionPipe(model_or_path)
else:
self.logger.error(f"model_or_path support only str or DiffusionPipeline currently!")
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
model_policy = model_policy_map.get(policy_map_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = model.to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
generation_config: DiffusionGenerationConfig = None,
**kwargs,
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
""" """
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
**gen_config_dict,
**kwargs,
)
output_reqs_list = []
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
while self.request_handler.check_unfinished_reqs():
output_reqs_list += self.step()
return output_reqs_list
def add_request(
self,
prompts: Union[List[str], str],
request_ids: Union[List[int], int] = None,
**kwargs,
):
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if not isinstance(prompts, list):
prompts = [prompts]
generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
prompts_num = len(prompts)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
else:
request_id = next(self.counter)
seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)
self.request_handler.add_sequence(seq)
def step(self) -> List[PIL.Image.Image]:
"""
In each step, do the follows:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. run forward to get List[Image]
Returns:
List[PIL.Image.Image]: Image Generated by one step.
"""
input = self.request_handler.schedule()
ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
return ret

800
colossalai/inference/core/engine.py

@ -1,57 +1,24 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import List, Tuple, Type, Union
import numpy as np
import torch
import PIL.Image
import torch.nn as nn
from torch import distributed as dist
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from diffusers import DiffusionPipeline
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.utils import ModelType, get_model_type
from colossalai.shardformer.policies.base_policy import Policy
from .request_handler import RequestHandler
__all__ = ["InferenceEngine"]
PP_AXIS, TP_AXIS = 0, 1
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class InferenceEngine:
"""
InferenceEngine which manages the inference process..
Args:
model_or_path (nn.Module or str): Path or nn.Module of this model.
model_or_path (nn.Module or DiffusionPipeline or str): Path or nn.Module or DiffusionPipeline of this model.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
@ -60,567 +27,68 @@ class InferenceEngine:
def __init__(
self,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
model_or_path: Union[nn.Module, str, DiffusionPipeline],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
self.counter = count()
self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
self.capture_model(self.k_cache, self.v_cache)
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self._verify_args()
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
pretrained_path = None
if isinstance(model_or_path, str):
import colossalai.interface.pretrained as pretrained_utils
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
if arch is "BaichuanForCausalLM":
self.logger.warning(
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
)
ctx = LazyInitContext(default_device="cuda")
with ctx:
model = _supported_models[arch].from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
pretrained_path = pretrained_utils.get_pretrained_path(model)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path
self.model_config = model.config
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
model = model.to(self.dtype).eval()
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
self.__dict__["_initialized"] = False # use __dict__ directly to avoid calling __setattr__
self.model_type = get_model_type(model_or_path=model_or_path)
self.engine = None
if self.model_type == ModelType.LLM:
from .llm_engine import LLMEngine
self.engine = LLMEngine(
model_or_path=model_or_path,
tokenizer=tokenizer,
inference_config=inference_config,
verbose=verbose,
model_policy=model_policy,
)
if model_policy is None:
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
model_policy = model_policy_map.get(model_policy_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = ModelWrapper(model).to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
elif self.model_type == ModelType.DIFFUSION_MODEL:
from .diffusion_engine import DiffusionEngine
self.engine = DiffusionEngine(
model_or_path=model_or_path,
inference_config=inference_config,
verbose=verbose,
model_policy=model_policy,
)
elif self.model_type == ModelType.UNKNOWN:
self.logger.error(f"Model Type either Difffusion or LLM!")
if pretrained_path:
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(pretrained_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please turn on the cuda graph"
if self.verbose:
self.logger.info("Colossal AI CUDA Graph Capture begin")
t_capture_begin = time.perf_counter()
block_size = self.inference_config.block_size
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
if self.verbose:
self.logger.info(f"batch size {batch_size} graph capturing")
input_meta_data = InputMetaData(
block_tables=block_tables[:batch_size],
sequence_lengths=sequence_lengths[:batch_size],
fd_inter_tensor=fd_inter_tensor,
batch_size=batch_size,
is_prompts=False,
use_cuda_graph=True,
high_precision=False,
kv_seq_len=sequence_lengths[:batch_size].max().item(),
head_dim=head_dim,
dtype=self.dtype,
)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens_ids[:batch_size],
output_tensor[:batch_size],
input_meta_data,
k_caches=k_cache,
v_caches=v_cache,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
t_capture_end = time.perf_counter()
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
self._initialized = True
self._verify_args()
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if isinstance(self.model, ModelWrapper):
model = self.model.module
assert (
model.__class__.__name__ in _supported_models.keys()
), f"Model {self.model.__class__.__name__} is not supported."
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
model_shard_infer_config: ModelShardInferenceConfig = None,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): Path or nn.Module of this model.
model_policy (Policy): The policy to shardformer model which is determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: The model optimized by Shardformer.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"model_shard_infer_config": model_shard_infer_config},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
If True, the drafter model will be replaced by a glide model.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.use_spec_dec:
self.disable_spec_dec()
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False
def steps_spec_dec(self) -> List[Sequence]:
"""
Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
with many steps of speculating by a drafter model as well as verifying by a main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_token_ids = batch.get_1D_inputs_spec_dec(1)
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
batch = self.request_handler.running_bb # running batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
# 3. Decoding - Drafter model speculates `n` tokens
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cache[-1], # use kv cahces of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
n_spec_tokens=self.n_spec_tokens,
)
drafter_out = self.drafter.speculate(
input_token_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)
# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_token_ids = batch.get_1D_inputs_spec_dec(n)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
# Reset back the number of speculated tokens of the batch,
# this is used to handle the last round of speculation, in which case the number of speculated tokens
# by the drafter is less than the number of speculated tokens set to the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
return finished_sequences
assert self.engine is not None, "Please init Engine first"
assert self._initialized, "Engine must be initialized"
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
*args,
**kwargs,
) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
prompts_token_ids=prompts_token_ids,
**gen_config_dict,
)
output_seqs_list = []
total_tokens_list = []
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
for seq in output_seqs_list:
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
if return_token_ids:
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
return output_str, output_tokens_list
else:
return output_str
@property
def has_prompt_template(self) -> bool:
""" """
return self.inference_config.prompt_template is not None
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
"""
This method will format the input prompt according to the prompt template given to the InferenceConfig.
"""
assert (
self.has_prompt_template
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
if isinstance(prompts, (list, tuple)):
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
elif isinstance(prompts, str):
return self.inference_config.prompt_template.format(input_text=prompts)
else:
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
assert self.engine is not None, "Please init Engine first"
return self.engine.generate(request_ids=request_ids, prompts=prompts, *args, **kwargs)
def add_request(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
*args,
**kwargs,
) -> None:
"""
@ -630,168 +98,36 @@ class InferenceEngine:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
kwargs: for LLM, it could be max_length, max_new_tokens, etc
for diffusion, it could be prompt_2, prompt_3, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, negative_prompt_3, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, clip_skip, which aligns with diffusers
"""
assert self.engine is not None, "Please init Engine first"
self.engine.add_request(request_ids=request_ids, prompts=prompts, *args, **kwargs)
# apply the prompt template to the input prompts
if self.has_prompt_template and prompts is not None:
prompts = self.format_prompt(prompts)
block_size = self.inference_config.block_size
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if prompts is not None and not isinstance(prompts, list):
prompts = [prompts]
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
"input_ids"
]
# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
else:
raise TypeError(
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
)
assert (
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
prompts_num = len(prompts_token_ids)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
else:
request_id = next(self.counter)
if prompts == None:
prompt = None
else:
prompt = prompts[i]
max_length = kwargs.get("max_length", None)
max_new_tokens = kwargs.get("max_new_tokens", None)
if max_length is None and max_new_tokens is None:
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i])
def step(self):
assert self.engine is not None, "Please init Engine first"
return self.engine.step()
if not self.inference_config.enable_streamingllm:
assert (
self.inference_config.max_output_len >= max_new_tokens
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
max_output_len=max_new_tokens,
ignore_eos=self.inference_config.ignore_eos,
)
self.request_handler.add_sequence(sequence)
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
n_tokens = sequence_lengths.sum().item()
else:
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
n_tokens = n_tokens * batch.current_batch_size
output_tensor = torch.zeros(
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
batch_token_ids = None
if (
self.generation_config.repetition_penalty != 1.0
or self.generation_config.no_repeat_ngram_size > 0
or self.generation_config.forced_eos_token_id is not None
):
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=batch.fd_inter_tensor,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
def __getattr__(self, name):
"""
In each step, do the follows:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Get the input, inputinfo and output placeholder from the batchbucket
3. Run model to generate the next token
4. Update waiting list and running list in RequestHandler and get finished sequences.
5. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
The Design logic of getattr, setattr:
1. Since InferenceEngine is a wrapper for DiffusionEngine/LLMEngine, we hope to invoke all the member of DiffusionEngine/LLMEngine like we just call the member of InferenceEngine.
2. When we call the __init__ of InferenceEngine, we don't want to setattr using self.__dict__["xxx"] = xxx, we want to use origin ways like self.xxx = xxx
So we set the attribute `_initialized`. And after initialized, if we couldn't get the member from InferenceEngine, we will try to get the member from self.engine(DiffusionEngine/LLMEngine)
"""
batch = self.request_handler.schedule()
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
if self.__dict__.get("_initialized", False):
if name in self.__dict__:
return self.__dict__[name]
else:
return getattr(self.engine, name)
else:
model_executable = self.model
return self.__dict__[name]
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
if self.inference_config.enable_streamingllm:
updated_block_ids = batch.streamingllm_update_batch(
self.inference_config.start_token_size, self.inference_config.generated_token_size
)
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences
def __setattr__(self, name, value):
if self.__dict__.get("_initialized", False):
if name in self.__dict__:
self.__dict__[name] = value
else:
setattr(self.engine, name, value)
else:
self.__dict__[name] = value

758
colossalai/inference/core/llm_engine.py

@ -0,0 +1,758 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy
from .base_engine import BaseEngine
from .request_handler import RequestHandler
PP_AXIS, TP_AXIS = 0, 1
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
class LLMEngine(BaseEngine):
"""
InferenceEngine which manages the inference process..
Args:
model_or_path (nn.Module or str): Path or nn.Module of this model.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
"""
def __init__(
self,
model_or_path: nn.Module | str,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
inference_config: InferenceConfig = None,
verbose: bool = False,
model_policy: Policy | type[Policy] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
self.high_precision = inference_config.high_precision
self.verbose = verbose
self.logger = get_dist_logger(__name__)
self.model_shard_infer_config = inference_config.to_model_shard_inference_config()
self.init_model(model_or_path, model_policy, self.model_shard_infer_config)
self.generation_config = inference_config.to_generation_config(self.model_config)
self.generation_config_dict = self.generation_config.to_dict()
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
self.counter = count()
self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
self.capture_model(self.k_cache, self.v_cache)
# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = self.inference_config.use_spec_dec
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self._verify_args()
def init_model(
self,
model_or_path: Union[nn.Module, str],
model_policy: Union[Policy, Type[Policy]] = None,
model_shard_infer_config: ModelShardInferenceConfig = None,
):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model.
model_inference_config: the configuration for modeling initialization when inference.
model_shard_infer_config (ModelShardInferenceConfig): the configuration for init of module when inference.
"""
pretrained_path = None
if isinstance(model_or_path, str):
import colossalai.interface.pretrained as pretrained_utils
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
arch = getattr(hf_config, "architectures")[0]
if arch in _supported_models.keys():
if arch == "BaichuanForCausalLM":
self.logger.warning(
"Attention ! We use lazy init by default, which could be faster for model loading. For baichuan model, the output maybe have a slight difference with transformers"
)
ctx = LazyInitContext(default_device="cuda")
with ctx:
model = _supported_models[arch].from_pretrained(
model_or_path, trust_remote_code=True, torch_dtype=self.dtype
)
pretrained_path = pretrained_utils.get_pretrained_path(model)
else:
# TODO(char-1ee): if the model not supported, use transformers APIs to load and generate
raise ValueError(f"Model {arch} is not supported.")
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path
self.model_config = model.config
torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]
self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")
model = model.to(self.dtype).eval()
if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)
if model_policy is None:
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
model_policy = model_policy_map.get(model_policy_key)
if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")
assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
self.model = self._shardformer(
model,
model_policy,
model_shard_infer_config,
None,
tp_group=tp_group,
)
self.model = ModelWrapper(model).to(self.device)
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)
if pretrained_path:
from colossalai.inference.core.plugin import InferCheckpoint_io
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(pretrained_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)
free_gpu_memory, _ = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)
@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
assert self.use_cuda_graph, "please turn on the cuda graph"
if self.verbose:
self.logger.info("Colossal AI CUDA Graph Capture begin")
t_capture_begin = time.perf_counter()
block_size = self.inference_config.block_size
head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
max_context_len_to_capture = self.inference_config.max_context_len_to_capture
max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
# self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
self.graph_block_tables[0, :] = np.arange(
0, max_num_blocks
) # NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
output_tensor = torch.zeros(
(max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
)
fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor
max_num_seqs = self.inference_config.max_batch_size
batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
# NOTE this is a hack to insure cuda grpah could capture the fixed cuda kernel grid in flash decoding, to make the first seqlen as the max_seq_len
sequence_lengths[0] = torch.tensor(
self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
if self.verbose:
self.logger.info(f"batch size {batch_size} graph capturing")
input_meta_data = InputMetaData(
block_tables=block_tables[:batch_size],
sequence_lengths=sequence_lengths[:batch_size],
fd_inter_tensor=fd_inter_tensor,
batch_size=batch_size,
is_prompts=False,
use_cuda_graph=True,
high_precision=False,
kv_seq_len=sequence_lengths[:batch_size].max().item(),
head_dim=head_dim,
dtype=self.dtype,
)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens_ids[:batch_size],
output_tensor[:batch_size],
input_meta_data,
k_caches=k_cache,
v_caches=v_cache,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
t_capture_end = time.perf_counter()
if self.verbose:
self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")
def _verify_args(self) -> None:
"""Verify the input args"""
if not isinstance(self.inference_config, InferenceConfig):
raise TypeError("Invalid type of inference config provided.")
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if isinstance(self.model, ModelWrapper):
model = self.model.module
assert (
model.__class__.__name__ in _supported_models.keys()
), f"Model {self.model.__class__.__name__} is not supported."
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
If provided, the previous drafter and drafter model, if exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
If True, the drafter model will be replaced by a glide model.
```python
...
engine = InferenceEngine(model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
engine.generate(...) # Speculative Decoding
engine.disable_spec_dec()
engine.generate(...) # Normal generation
engine.enable_spec_dec()
engine.generate(...) # Speculative-Decoding using previously set drafter model and number of spec tokens
engine.clear_spec_dec()
```
"""
if drafter_model is None and self.drafter is None:
raise ValueError("Drafter not initialized. Please provide a Drafter Model")
if n_spec_tokens is not None:
assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
self.n_spec_tokens = n_spec_tokens
if drafter_model is not None:
assert isinstance(drafter_model, nn.Module)
# overwrite the drafter, if exists
self.clear_spec_dec()
self.drafter_model = drafter_model
self.drafter = Drafter(
self.drafter_model,
self.tokenizer,
device=self.device,
dtype=self.dtype,
)
# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
def disable_spec_dec(self) -> None:
"""Disable using speculative decoding for subsequent generations."""
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False
def clear_spec_dec(self) -> None:
"""Clear relatable structures of speculative decoding, if exist."""
if self.use_spec_dec:
self.disable_spec_dec()
if self.drafter_model or self.drafter:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False
def steps_spec_dec(self) -> List[Sequence]:
"""
Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
with many steps of speculating by a drafter model as well as verifying by a main model.
Returns:
List[Sequence]: finished sequences generated by one step.
"""
batch = self.request_handler.schedule() # prefill batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
drafter_out = self.drafter.speculate(input_token_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
# 2. Prefill main model (Verifier) - fill past kv cache for main model
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# append new inputs to the batch, temporarily
batch.append_batch_tokens(next_tokens)
self.request_handler.allocate_batch_spec_dec(batch, 1)
already_allocated_kv_len = batch.seq_lengths[0].item()
input_token_ids = batch.get_1D_inputs_spec_dec(1)
finished_sequences = self.request_handler.update()
while True:
# HACK Retrieve the running batch
# Using RequestHandler.schedule here will re-allocate same kv cache for the batch
batch = self.request_handler.running_bb # running batch
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
# 3. Decoding - Drafter model speculates `n` tokens
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cache[-1], # use kv cahces of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
n_spec_tokens=self.n_spec_tokens,
)
drafter_out = self.drafter.speculate(
input_token_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
for next_token_id_spec in next_token_ids_spec:
self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
cur_length = batch.seq_lengths[0].item()
if already_allocated_kv_len < cur_length:
self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
# 5. Compare and process the results
diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
# trim past key values of the drafter model
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)
# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_token_ids = batch.get_1D_inputs_spec_dec(n)
self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
finished_sequences = self.request_handler.update()
if len(finished_sequences) > 0:
break
# Reset back the number of speculated tokens of the batch,
# this is used to handle the last round of speculation, in which case the number of speculated tokens
# by the drafter is less than the number of speculated tokens set to the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
return finished_sequences
def generate(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
"""
Executing the inference step.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
"""
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
prompts = [prompts] if isinstance(prompts, str) else prompts
request_ids = [request_ids] if isinstance(request_ids, int) else request_ids
with torch.inference_mode():
if prompts is not None or prompts_token_ids is not None:
self.add_request(
request_ids=request_ids,
prompts=prompts,
prompts_token_ids=prompts_token_ids,
**gen_config_dict,
)
output_seqs_list = []
total_tokens_list = []
# intuition: If user provide a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config
self.generation_config_dict = gen_config_dict
if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
while self.request_handler.check_unfinished_reqs():
output_seqs_list += self.steps_spec_dec()
else:
while self.request_handler.check_unfinished_reqs():
output_seqs_list += self.step()
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
for seq in output_seqs_list:
total_tokens_list.append(seq.input_token_id + seq.output_token_id)
output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)
if return_token_ids:
output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
return output_str, output_tokens_list
else:
return output_str
@property
def has_prompt_template(self) -> bool:
""" """
return self.inference_config.prompt_template is not None
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
"""
This method will format the input prompt according to the prompt template given to the InferenceConfig.
"""
assert (
self.has_prompt_template
), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."
if isinstance(prompts, (list, tuple)):
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
elif isinstance(prompts, str):
return self.inference_config.prompt_template.format(input_text=prompts)
else:
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
def add_request(
self,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
**kwargs,
) -> None:
"""
Add requests.
Args:
request_ids (List[int], optional): The request ID. Defaults to None.
prompts (Union[List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
"""
# apply the prompt template to the input prompts
if self.has_prompt_template and prompts is not None:
prompts = self.format_prompt(prompts)
block_size = self.inference_config.block_size
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if prompts is not None and not isinstance(prompts, list):
prompts = [prompts]
if prompts_token_ids is None:
assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
"input_ids"
]
# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
else:
raise TypeError(
f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
)
assert (
len(prompts_token_ids[0]) <= self.inference_config.max_input_len
), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."
prompts_num = len(prompts_token_ids)
for i in range(prompts_num):
if request_ids:
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
assert len(request_ids) == prompts_num
request_id = request_ids[i]
else:
request_id = next(self.counter)
if prompts == None:
prompt = None
else:
prompt = prompts[i]
max_length = kwargs.get("max_length", None)
max_new_tokens = kwargs.get("max_new_tokens", None)
if max_length is None and max_new_tokens is None:
max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i])
if not self.inference_config.enable_streamingllm:
assert (
self.inference_config.max_output_len >= max_new_tokens
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
self.tokenizer.eos_token_id,
self.tokenizer.pad_token_id,
max_output_len=max_new_tokens,
ignore_eos=self.inference_config.ignore_eos,
)
self.request_handler.add_sequence(sequence)
def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
input_ids = batch.get_1D_inputs()
sequence_lengths = batch.get_sequence_lengths()
if batch.is_prompts:
n_tokens = sequence_lengths.sum().item()
else:
n_tokens = batch.current_batch_size
if batch.use_spec_dec:
n_tokens = batch.num_tokens_to_verify + 1
assert n_tokens == input_ids.size(0)
n_tokens = n_tokens * batch.current_batch_size
output_tensor = torch.zeros(
(n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
)
batch_token_ids = None
if (
self.generation_config.repetition_penalty != 1.0
or self.generation_config.no_repeat_ngram_size > 0
or self.generation_config.forced_eos_token_id is not None
):
batch_token_ids = batch.batch_token_ids
# only when we have the graph for specific decoding batch size can we use the cuda graph for inference
use_cuda_graph = False
if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
use_cuda_graph = True
input_meta_data = InputMetaData(
block_tables=batch.get_block_table_tensor(),
sequence_lengths=sequence_lengths,
fd_inter_tensor=batch.fd_inter_tensor,
batch_size=batch.current_batch_size,
is_prompts=batch.is_prompts,
use_cuda_kernel=self.inference_config.use_cuda_kernel,
use_cuda_graph=use_cuda_graph,
high_precision=self.high_precision,
kv_seq_len=sequence_lengths.max().item(),
head_dim=batch.head_dim,
dtype=batch.dtype,
use_spec_dec=batch.use_spec_dec,
num_tokens_to_verify=batch.num_tokens_to_verify,
batch_token_ids=batch_token_ids,
)
return input_ids, output_tensor, input_meta_data
def step(self) -> List[str]:
"""
In each step, do the follows:
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Get the input, inputinfo and output placeholder from the batchbucket
3. Run model to generate the next token
4. Update waiting list and running list in RequestHandler and get finished sequences.
5. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
"""
batch = self.request_handler.schedule()
input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
if input_meta_data.use_cuda_graph:
model_executable = self.graph_runners[input_meta_data.batch_size]
else:
model_executable = self.model
# TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
if self.inference_config.enable_streamingllm:
updated_block_ids = batch.streamingllm_update_batch(
self.inference_config.start_token_size, self.inference_config.generated_token_size
)
self.request_handler.streamingllm_free_block_tables(updated_block_ids)
next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
self.request_handler.append_next_tokens(next_tokens)
finished_sequences = self.request_handler.update()
return finished_sequences

51
colossalai/inference/core/request_handler.py

@ -8,7 +8,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import RequestStatus, Sequence
from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
@ -98,7 +98,46 @@ class RunningList:
self._decoding[seq_id] = self._prefill.pop(seq_id)
class RequestHandler:
class NaiveRequestHandler:
def __init__(self) -> None:
self.running_list: List[DiffusionSequence] = []
self.waiting_list: List[str] = []
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def _has_running(self) -> bool:
return any(lst for lst in self.running_list)
def check_unfinished_reqs(self):
return self._has_waiting() or self._has_running()
def add_sequence(self, seq: DiffusionSequence):
"""
Add the request to waiting list.
"""
assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
self.waiting_list.append(seq)
def _find_sequence(self, request_id: int) -> DiffusionSequence:
"""
Find the request by request_id.
"""
for lst in enumerate(self.waiting_list + self.running_list):
for seq in lst:
if seq.request_id == request_id:
return seq
return None
def schedule(self):
ret = None
if self._has_waiting:
ret = self.waiting_list[0]
self.waiting_list = self.waiting_list[1:]
return ret
class RequestHandler(NaiveRequestHandler):
"""
RequestHandler is the core for handling existing requests and updating current batch.
During generation process, we call schedule function each iteration to update current batch.
@ -176,12 +215,12 @@ class RequestHandler:
generated_token_size=inference_config.generated_token_size,
)
def _has_running(self) -> bool:
return not self.running_bb.is_empty()
def _init_cache(self, model_config):
self.cache_manager = KVCacheManager(self.inference_config, model_config)
def _has_waiting(self) -> bool:
return any(lst for lst in self.waiting_list)
def get_kvcache(self):
return self.cache_manager.get_kv_cache()
@ -318,7 +357,7 @@ class RequestHandler:
if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
seq.mark_finished()
def check_unfinished_seqs(self) -> bool:
def check_unfinished_reqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
def total_requests_in_batch_bucket(self) -> int:

54
colossalai/inference/modeling/models/diffusion.py

@ -0,0 +1,54 @@
import inspect
import types
import torch
from torch import nn
class DiffusionPipe(nn.Module):
"""
This Class convert a class of `DiffusionPipeline` into `nn.Module` and reserve most of origin attr,function and property.
"""
def __init__(self, source_obj) -> None:
super(DiffusionPipe, self).__init__()
for k, v in source_obj.__dict__.items():
if isinstance(v, nn.Module):
self.add_module(k, v)
else:
setattr(self, k, v)
skip_list = ["_execution_device", "to", "device"] # this
for name, member in inspect.getmembers(source_obj.__class__):
if name in skip_list:
continue
if not name.startswith("__") and not name.endswith("__"):
if isinstance(member, property):
setattr(self.__class__, name, member)
elif inspect.isfunction(member) or inspect.ismethod(member):
bound_method = types.MethodType(member, self)
setattr(self, name, bound_method)
elif not callable(member) and not isinstance(member, property):
setattr(self, name, member)
elif name == "__call__":
bound_method = types.MethodType(member, self)
setattr(self, "_forward", bound_method)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
# return self.device
return torch.device("cuda")
@property
def device(self):
next(self.parameters()).device
def forward(self, *args, **kwargs):
return self._forward(*args, **kwargs)

220
colossalai/inference/modeling/models/pixart_alpha.py

@ -0,0 +1,220 @@
# Code adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
from typing import Callable, List, Optional, Union
import PIL.Image
import torch
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_256_BIN,
ASPECT_RATIO_512_BIN,
ASPECT_RATIO_1024_BIN,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
from colossalai.logging import get_dist_logger
from .diffusion import DiffusionPipe
logger = get_dist_logger(__name__)
@torch.no_grad()
def pixart_alpha_forward(
self: DiffusionPipe,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 120,
**kwargs,
) -> PIL.Image:
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.transformer.config.sample_size == 128:
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
if do_classifier_free_guidance:
resolution = torch.cat([resolution, resolution], dim=0)
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
if num_inference_steps == 1:
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
output_type = "pil" # TODO(@lry89757) temporarily image, please support more return output
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
# self.maybe_free_model_hooks()
return image

178
colossalai/inference/modeling/models/stablediffusion3.py

@ -0,0 +1,178 @@
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from .diffusion import DiffusionPipe
# TODO(@lry89757) temporarily image, please support more return output
@torch.no_grad()
def sd3_forward(
self: DiffusionPipe,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
prompt_3,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
return image

6
colossalai/inference/modeling/policy/__init__.py

@ -1,16 +1,22 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
from .pixart_alpha import PixArtAlphaInferPolicy
from .stablediffusion3 import StableDiffusion3InferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
"nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
"StableDiffusion3Pipeline": StableDiffusion3InferPolicy,
"PixArtAlphaPipeline": PixArtAlphaInferPolicy,
}
__all__ = [
"NoPaddingLlamaModelInferPolicy",
"NoPaddingBaichuanModelInferPolicy",
"GlideLlamaModelPolicy",
"StableDiffusion3InferPolicy",
"PixArtAlphaInferPolicy",
"model_polic_map",
]

34
colossalai/inference/modeling/policy/pixart_alpha.py

@ -0,0 +1,34 @@
from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
from colossalai.shardformer.policies.base_policy import Policy
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = {}
self.append_or_create_method_replacement(
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
)
return policy
def preprocess(self) -> nn.Module:
return self.model
def postprocess(self):
return self.model
def config_sanity_check(self):
pass
def to_rpc_param(self) -> str:
return __class__.__name__
@staticmethod
def from_rpc_param() -> "PixArtAlphaInferPolicy":
return PixArtAlphaInferPolicy()

34
colossalai/inference/modeling/policy/stablediffusion3.py

@ -0,0 +1,34 @@
from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
from colossalai.shardformer.policies.base_policy import Policy
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = {}
self.append_or_create_method_replacement(
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
)
return policy
def preprocess(self) -> nn.Module:
return self.model
def postprocess(self):
return self.model
def config_sanity_check(self):
pass
def to_rpc_param(self) -> str:
return __class__.__name__
@staticmethod
def from_rpc_param() -> "StableDiffusion3InferPolicy":
return StableDiffusion3InferPolicy()

12
colossalai/inference/struct.py

@ -2,6 +2,7 @@ import enum
from dataclasses import dataclass
from typing import Any, List
from colossalai.inference.config import DiffusionGenerationConfig
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
@ -46,6 +47,17 @@ class RequestStatus(enum.Enum):
return status == RequestStatus.WAITING
@dataclass
class DiffusionSequence:
"""
parameters for diffusion
"""
request_id: int
prompt: str
generation_config: DiffusionGenerationConfig
@dataclass
class Sequence:
"""Store information of input sequence.

39
colossalai/inference/utils.py

@ -5,10 +5,12 @@ Utils for model inference
import math
import os
import re
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline
from torch import nn
from colossalai.logging import get_dist_logger
@ -159,3 +161,38 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
except ImportError:
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
return False
class ModelType(Enum):
DIFFUSION_MODEL = "Diffusion Model"
LLM = "Large Language Model (LLM)"
UNKNOWN = "Unknown Model Type"
def get_model_type(model_or_path: Union[nn.Module, str, DiffusionPipeline]):
if isinstance(model_or_path, DiffusionPipeline):
return ModelType.DIFFUSION_MODEL
elif isinstance(model_or_path, nn.Module):
return ModelType.LLM
elif isinstance(model_or_path, str):
try:
from transformers import AutoConfig
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
return ModelType.LLM
except:
"""
model type is not `ModelType.LLM`
"""
try:
from diffusers import DiffusionPipeline
DiffusionPipeline.load_config(model_or_path)
return ModelType.DIFFUSION_MODEL
except:
"""
model type is not `ModelType.DIFFUSION_MODEL`
"""
else:
return ModelType.UNKNOWN

75
examples/inference/stable_diffusion/sd3_generation.py

@ -0,0 +1,75 @@
import argparse
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
from torch import bfloat16, float16, float32
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
# For Stable Diffusion 3, we'll use the following configuration
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
TORCH_DTYPE_MAP = {
"fp16": float16,
"fp32": float32,
"bf16": bfloat16,
}
def infer(args):
# ==============================
# Launch colossalai, setup distributed environment
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ==============================
# Load model and tokenizer
# ==============================
model_path_or_name = args.model
model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
# ==============================
# Initialize InferenceEngine
# ==============================
coordinator.print_on_master(f"Initializing Inference Engine...")
inference_config = InferenceConfig(
dtype=args.dtype,
max_batch_size=args.max_batch_size,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
)
engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
# ==============================
# Generation
# ==============================
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
out.save("cat.jpg")
coordinator.print_on_master(out)
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
if __name__ == "__main__":
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
args = parser.parse_args()
infer(args)

1
requirements/requirements.txt

@ -23,3 +23,4 @@ rpyc==6.0.0
fastapi
uvicorn==0.29.0
galore_torch
diffusers==0.29.0

Loading…
Cancel
Save