mirror of https://github.com/hpcaitech/ColossalAI
* Diffusion Model Inference support
* Stable Diffusion 3 support
* PixArt-Alpha support

pull/5894/head
Runyu Lu, 5 months ago, committed by GitHub
16 changed files with 1860 additions and 740 deletions
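For context, here is a minimal usage sketch of the diffusion inference path this PR adds, based on the DiffusionEngine API shown in the diff below. The module path, checkpoint id, and the default InferenceConfig() arguments are illustrative assumptions, not part of the change:

import torch

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.diffusion_engine import DiffusionEngine  # module path assumed from the relative imports in the diff

# NOTE: the distributed backend must already be initialized (e.g. via colossalai's launch utilities),
# since the engine builds a ProcessGroupMesh internally; that setup is omitted here.
inference_config = InferenceConfig()  # default single-GPU settings assumed to be sufficient

engine = DiffusionEngine(
    model_or_path="PixArt-alpha/PixArt-XL-2-1024-MS",  # assumed example DiffusionPipeline checkpoint
    inference_config=inference_config,
    verbose=True,
)

# generate() accepts a str or a list of prompts and returns the images produced by the wrapped pipeline
images = engine.generate(prompts="An astronaut riding a horse on Mars")
images[0].save("output.png")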
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy


class BaseEngine(ABC):
    @abstractmethod
    def __init__(self, model_or_path, inference_config=None, verbose=False, model_policy=None):
        pass

    @abstractmethod
    def init_model(self, model_or_path, model_policy=None, model_shard_infer_config=None):
        """
        Initialize the model for the engine.
        """

    @abstractmethod
    def generate(self, request_ids=None, prompts=None, generation_config=None, **kwargs):
        """
        Generate output for incoming requests.
        """

    @abstractmethod
    def add_request(self, prompts, request_ids=None, **kwargs):
        """
        Add a new request to the engine.
        """

    @abstractmethod
    def step(self):
        """
        Perform one new forward step.
        """

    @abstractmethod
    def _verify_args(self):
        """
        Verify the parameters and members of the class.
        """

    @torch.inference_mode()
    def capture_model(self):
        """
        Use CUDA graph to capture the model.
        """
        raise NotImplementedError("This method should be implemented by subclasses")

    def _shardformer(
        self,
        model: nn.Module,
        model_policy: Policy,
        model_shard_infer_config: ModelShardInferenceConfig = None,
        stage_manager: PipelineStageManager = None,
        tp_group: ProcessGroupMesh = None,
        **kwargs,
    ) -> nn.Module:
        """
        Initialize ShardConfig and replace the model with shardformer.

        Args:
            model (nn.Module): Path or nn.Module of this model.
            model_policy (Policy): The policy to shard the model, determined by the model type.
            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
            tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

        Returns:
            nn.Module: The model optimized by Shardformer.
        """

        shardconfig = ShardConfig(
            tensor_parallel_process_group=tp_group,
            pipeline_stage_manager=stage_manager,
            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
            enable_fused_normalization=False,
            enable_all_optimization=False,
            enable_flash_attention=False,
            enable_jit_fused=False,
            enable_sequence_parallelism=False,
            extra_kwargs={"model_shard_infer_config": model_shard_infer_config, **kwargs},
        )
        shardformer = ShardFormer(shard_config=shardconfig)
        shard_model, _ = shardformer.optimize(model, model_policy)
        return shard_model

@@ -0,0 +1,200 @@
from itertools import count
from typing import List, Tuple, Type, Union

import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from torch import distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .base_engine import BaseEngine
from .request_handler import NaiveRequestHandler

PP_AXIS, TP_AXIS = 0, 1


class DiffusionEngine(BaseEngine):
    def __init__(
        self,
        model_or_path: DiffusionPipeline | str,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Policy | type[Policy] = None,
    ) -> None:
        self.inference_config = inference_config
        self.dtype = inference_config.dtype
        self.high_precision = inference_config.high_precision

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

        self.model_type = get_model_type(model_or_path=model_or_path)

        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.request_handler = NaiveRequestHandler()

        self.counter = count()

        self._verify_args()

    def _verify_args(self) -> None:
        assert isinstance(self.model, DiffusionPipe), "model must be DiffusionPipe"

    def init_model(
        self,
        model_or_path: Union[str, nn.Module, DiffusionPipeline],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        """
        Shard the model and/or load its weights.

        Args:
            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformers format.
            model_policy (Policy): the policy to replace the model.
            model_shard_infer_config (ModelShardInferenceConfig): the configuration for initializing modules for inference.
        """
        if isinstance(model_or_path, str):
            model = DiffusionPipeline.from_pretrained(model_or_path, torch_dtype=self.dtype)
            policy_map_key = model.__class__.__name__
            model = DiffusionPipe(model)
        elif isinstance(model_or_path, DiffusionPipeline):
            policy_map_key = model_or_path.__class__.__name__
            model = DiffusionPipe(model_or_path)
        else:
            self.logger.error("model_or_path supports only str or DiffusionPipeline currently!")

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        if self.verbose:
            self.logger.info(f"the device is {self.device}")

        if self.verbose:
            self.logger.info(
                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            model_policy = model_policy_map.get(policy_map_key)

        if not isinstance(model_policy, Policy):
            try:
                model_policy = model_policy()
            except Exception as e:
                raise ValueError(f"Unable to instantiate model policy: {e}")

        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            model_shard_infer_config,
            None,
            tp_group=tp_group,
        )

        self.model = model.to(self.device)

        if self.verbose:
            self.logger.info(
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )

    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        generation_config: DiffusionGenerationConfig = None,
        **kwargs,
    ) -> Union[List[Union[str, List[PIL.Image.Image], np.ndarray]], Tuple[List[str], List[List[int]]]]:
        """
        Generate outputs (e.g. images) for the given prompts.
        """
        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        prompts = [prompts] if isinstance(prompts, str) else prompts
        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

        with torch.inference_mode():
            if prompts is not None:
                self.add_request(
                    request_ids=request_ids,
                    prompts=prompts,
                    **gen_config_dict,
                    **kwargs,
                )

            output_reqs_list = []

            # intuition: if the user provides a generation config, replace the existing one.
            if generation_config is not None:
                self.generation_config = generation_config
                self.generation_config_dict = gen_config_dict

            while self.request_handler.check_unfinished_reqs():
                output_reqs_list += self.step()

            return output_reqs_list

    def add_request(
        self,
        prompts: Union[List[str], str],
        request_ids: Union[List[int], int] = None,
        **kwargs,
    ):
        if request_ids is not None and not isinstance(request_ids, list):
            request_ids = [request_ids]

        if not isinstance(prompts, list):
            prompts = [prompts]

        generation_config = DiffusionGenerationConfig.from_kwargs(**kwargs)
        prompts_num = len(prompts)
        for i in range(prompts_num):
            if request_ids:
                assert isinstance(
                    request_ids[0], int
                ), f"The request_id type must be int, but got {type(request_ids[0])}"
                assert len(request_ids) == prompts_num
                request_id = request_ids[i]
            else:
                request_id = next(self.counter)

            seq = DiffusionSequence(request_id=request_id, prompt=prompts[i], generation_config=generation_config)

            self.request_handler.add_sequence(seq)

    def step(self) -> List[PIL.Image.Image]:
        """
        In each step, do the following:
            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Run forward to get List[Image].
        Returns:
            List[PIL.Image.Image]: Images generated by one step.
        """

        input = self.request_handler.schedule()
        ret = self.model(prompt=input.prompt, **input.generation_config.to_dict())
        return ret

@@ -0,0 +1,758 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData, ModelShardInferenceConfig
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.sampler import search_tokens
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyInitContext
from colossalai.logging import get_dist_logger
from colossalai.shardformer.policies.base_policy import Policy

from .base_engine import BaseEngine
from .request_handler import RequestHandler

PP_AXIS, TP_AXIS = 0, 1

_supported_models = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "BaichuanForCausalLM": AutoModelForCausalLM,
}

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


class LLMEngine(BaseEngine):
    """
    InferenceEngine which manages the inference process.

    Args:
        model_or_path (nn.Module or str): Path or nn.Module of this model.
        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast], optional): Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
        model_policy ("Policy"): the policy to shard the model. It will be determined by the model type if not provided.
    """

    def __init__(
        self,
        model_or_path: nn.Module | str,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
        inference_config: InferenceConfig = None,
        verbose: bool = False,
        model_policy: Policy | type[Policy] = None,
    ) -> None:
        self.inference_config = inference_config
        self.dtype = inference_config.dtype
        self.high_precision = inference_config.high_precision

        self.verbose = verbose
        self.logger = get_dist_logger(__name__)
        self.model_shard_infer_config = inference_config.to_model_shard_inference_config()

        self.init_model(model_or_path, model_policy, self.model_shard_infer_config)

        self.generation_config = inference_config.to_generation_config(self.model_config)
        self.generation_config_dict = self.generation_config.to_dict()

        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.request_handler = RequestHandler(self.inference_config, self.model_config)
        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
        # DISCUSS maybe move this into batch info?

        self.counter = count()

        self.use_cuda_graph = self.inference_config.use_cuda_graph
        if self.use_cuda_graph:
            self.graph_runners: Dict[int, CUDAGraphRunner] = {}
            self.graph_memory_pool = None  # Set during graph capture.
            if verbose:
                self.logger.info("Colossal AI CUDA Graph Capture on")

            self.capture_model(self.k_cache, self.v_cache)

        # Model and related attrs of speculative decoding will be set by `enable_spec_dec`
        self.use_spec_dec = self.inference_config.use_spec_dec

        self.drafter_model = None
        self.drafter = None
        self.use_glide = False
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens

        self._verify_args()

    def init_model(
        self,
        model_or_path: Union[nn.Module, str],
        model_policy: Union[Policy, Type[Policy]] = None,
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        """
        Shard the model and/or load its weights.

        Args:
            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformers format.
            model_policy (Policy): the policy to replace the model.
            model_shard_infer_config (ModelShardInferenceConfig): the configuration for initializing modules for inference.
        """
        pretrained_path = None
        if isinstance(model_or_path, str):
            import colossalai.interface.pretrained as pretrained_utils

            try:
                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True, torch_dtype=self.dtype)
                arch = getattr(hf_config, "architectures")[0]
                if arch in _supported_models.keys():
                    if arch == "BaichuanForCausalLM":
                        self.logger.warning(
                            "Attention! We use lazy init by default, which could be faster for model loading. For the Baichuan model, the output may have a slight difference from transformers."
                        )
                    ctx = LazyInitContext(default_device="cuda")
                    with ctx:
                        model = _supported_models[arch].from_pretrained(
                            model_or_path, trust_remote_code=True, torch_dtype=self.dtype
                        )
                    pretrained_path = pretrained_utils.get_pretrained_path(model)
                else:
                    # TODO(char-1ee): if the model is not supported, use transformers APIs to load and generate
                    raise ValueError(f"Model {arch} is not supported.")

            except Exception as e:
                self.logger.error(
                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
                )
        else:
            model = model_or_path

        self.model_config = model.config

        torch.cuda.empty_cache()
        init_gpu_memory = torch.cuda.mem_get_info()[0]

        self.device = get_accelerator().get_current_device()
        if self.verbose:
            self.logger.info(f"the device is {self.device}")

        model = model.to(self.dtype).eval()

        if self.verbose:
            self.logger.info(
                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
            )

        if model_policy is None:
            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
            model_policy = model_policy_map.get(model_policy_key)

        if not isinstance(model_policy, Policy):
            try:
                model_policy = model_policy()
            except Exception as e:
                raise ValueError(f"Unable to instantiate model policy: {e}")

        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

        self.model = self._shardformer(
            model,
            model_policy,
            model_shard_infer_config,
            None,
            tp_group=tp_group,
        )

        self.model = ModelWrapper(model).to(self.device)

        if self.verbose:
            self.logger.info(
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

        if pretrained_path:
            from colossalai.inference.core.plugin import InferCheckpoint_io

            cpt_io = InferCheckpoint_io()
            if_has_index_file, model_index_file = has_index_file(pretrained_path)
            assert if_has_index_file, "the model path is invalid"
            cpt_io.load_model(self.model, model_index_file)

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
        if self.verbose:
            self.logger.info(
                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
            )

    @torch.inference_mode()
    def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
        assert self.use_cuda_graph, "please turn on the cuda graph"

        if self.verbose:
            self.logger.info("Colossal AI CUDA Graph Capture begin")

        t_capture_begin = time.perf_counter()

        block_size = self.inference_config.block_size
        head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        max_context_len_to_capture = self.inference_config.max_context_len_to_capture
        max_num_blocks = (max_context_len_to_capture + block_size - 1) // block_size
        input_tokens_ids = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        # self.graph_block_tables = np.zeros((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
        self.graph_block_tables = np.full((max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), -1, dtype=np.int32)
        self.graph_block_tables[:, 0] = np.arange(max_num_blocks, max_num_blocks + max(_BATCH_SIZES_TO_CAPTURE))
        # NOTE This is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding,
        # by making the first sequence length equal to max_seq_len.
        self.graph_block_tables[0, :] = np.arange(0, max_num_blocks)
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
        output_tensor = torch.zeros(
            (max_batch_size, self.model_config.num_attention_heads * head_dim), dtype=self.dtype, device=self.device
        )
        fd_inter_tensor = self.request_handler.running_bb.fd_inter_tensor

        max_num_seqs = self.inference_config.max_batch_size
        batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_num_seqs]
        sequence_lengths = torch.ones(max_batch_size, dtype=torch.int).cuda()
        # NOTE This is a hack to ensure the CUDA graph captures the fixed CUDA kernel grid used in flash decoding,
        # by making the first sequence length equal to max_seq_len.
        sequence_lengths[0] = torch.tensor(
            self.inference_config.max_context_len_to_capture - 1, dtype=torch.int32
        ).cuda()

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(batch_size_capture_list):
            if self.verbose:
                self.logger.info(f"batch size {batch_size} graph capturing")

            input_meta_data = InputMetaData(
                block_tables=block_tables[:batch_size],
                sequence_lengths=sequence_lengths[:batch_size],
                fd_inter_tensor=fd_inter_tensor,
                batch_size=batch_size,
                is_prompts=False,
                use_cuda_graph=True,
                high_precision=False,
                kv_seq_len=sequence_lengths[:batch_size].max().item(),
                head_dim=head_dim,
                dtype=self.dtype,
            )

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens_ids[:batch_size],
                output_tensor[:batch_size],
                input_meta_data,
                k_caches=k_cache,
                v_caches=v_cache,
                memory_pool=self.graph_memory_pool,
            )
            self.graph_memory_pool = graph_runner.graph.pool()
            self.graph_runners[batch_size] = graph_runner

        t_capture_end = time.perf_counter()

        if self.verbose:
            self.logger.info(f"CUDA Graph capture time: {t_capture_end - t_capture_begin} s")

    def _verify_args(self) -> None:
        """Verify the input args"""
        if not isinstance(self.inference_config, InferenceConfig):
            raise TypeError("Invalid type of inference config provided.")
        if not isinstance(self.model, nn.Module):
            raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
        if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)):
            raise TypeError(
                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
            )
        if isinstance(self.model, ModelWrapper):
            model = self.model.module
        assert (
            model.__class__.__name__ in _supported_models.keys()
        ), f"Model {self.model.__class__.__name__} is not supported."

    def enable_spec_dec(
        self,
        drafter_model: nn.Module = None,
        n_spec_tokens: int = None,
        use_glide_drafter: bool = False,
    ) -> None:
        """Initialize the drafter (if it has not been initialized yet), and enable speculative decoding for subsequent generations.

        Args:
            drafter_model (nn.Module): The drafter model (small model) used to speculate tokens.
                If provided, the previous drafter and drafter model, if they exist, will be overwritten.
            n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
                If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
            use_glide_drafter (bool): Whether to use a GLIDE model for speculative decoding. Defaults to False.
                If True, the drafter model will be replaced by a GLIDE model.

        ```python
        ...
        engine = InferenceEngine(model, tokenizer, inference_config)

        engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
        engine.generate(...)  # Speculative decoding

        engine.disable_spec_dec()
        engine.generate(...)  # Normal generation

        engine.enable_spec_dec()
        engine.generate(...)  # Speculative decoding using the previously set drafter model and number of spec tokens
        engine.clear_spec_dec()
        ```
        """

        if drafter_model is None and self.drafter is None:
            raise ValueError("Drafter not initialized. Please provide a Drafter Model")
        if n_spec_tokens is not None:
            assert 1 < n_spec_tokens <= self.inference_config.max_n_spec_tokens
            self.n_spec_tokens = n_spec_tokens
        if drafter_model is not None:
            assert isinstance(drafter_model, nn.Module)
            # overwrite the drafter, if it exists
            self.clear_spec_dec()
            self.drafter_model = drafter_model
            self.drafter = Drafter(
                self.drafter_model,
                self.tokenizer,
                device=self.device,
                dtype=self.dtype,
            )

            # check if the provided drafter model is compatible with the GLIDE structure
            # when `use_glide_drafter` is set to True
            if (
                use_glide_drafter
                and hasattr(drafter_model, "model")
                and hasattr(drafter_model.model, "layers")
                and hasattr(drafter_model.model.layers[0], "cross_attn")
            ):
                self.use_glide = use_glide_drafter
            elif use_glide_drafter:
                self.logger.warning(
                    f"`use_glide_drafter` is provided as {use_glide_drafter}, "
                    f"but the provided drafter model is not compatible with the GLIDE structure. "
                    f"Falling back to the default drafter model (non-GLIDE)."
                )
        self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
        # use speculative decoding for subsequent generations
        self.use_spec_dec = True

    def disable_spec_dec(self) -> None:
        """Disable using speculative decoding for subsequent generations."""
        self.request_handler.unset_spec_dec_mode()
        # set back to the maximum number of tokens to speculate
        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
        self.use_glide = False
        self.use_spec_dec = False

    def clear_spec_dec(self) -> None:
        """Clear related structures of speculative decoding, if they exist."""
        if self.use_spec_dec:
            self.disable_spec_dec()
        if self.drafter_model or self.drafter:
            self.drafter_model = None
            self.drafter = None
            torch.cuda.empty_cache()
        self.use_glide = False
        self.use_spec_dec = False

    def steps_spec_dec(self) -> List[Sequence]:
        """
        Run speculative decoding steps. This retrieves a single batch and launches inference
        with multiple rounds of speculation by a drafter model and verification by the main model.

        Returns:
            List[Sequence]: finished sequences generated by one step.
        """
        batch = self.request_handler.schedule()  # prefill batch
        assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
        else:
            model_executable = self.model

        # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
        # NOTE For glide drafter models, we won't actually apply glide during the prefill stage
        drafter_out = self.drafter.speculate(input_token_ids, 1, None)
        next_token_ids_spec = drafter_out.next_tokens
        drafter_past_key_values = drafter_out.past_key_values

        # 2. Prefill main model (Verifier) - fill past kv cache for main model
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
        # append new inputs to the batch, temporarily
        batch.append_batch_tokens(next_tokens)
        self.request_handler.allocate_batch_spec_dec(batch, 1)
        already_allocated_kv_len = batch.seq_lengths[0].item()
        input_token_ids = batch.get_1D_inputs_spec_dec(1)

        finished_sequences = self.request_handler.update()

        while True:
            # HACK Retrieve the running batch directly.
            #      Using RequestHandler.schedule here would re-allocate the same kv cache for the batch.
            batch = self.request_handler.running_bb  # running batch
            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

            # 3. Decoding - Drafter model speculates `n` tokens
            glide_input = None
            if self.use_glide:
                glide_input = GlideInput(
                    batch.get_block_table_tensor(),
                    self.k_cache[-1],  # use kv caches of the last layer
                    self.v_cache[-1],
                    batch.get_sequence_lengths(),
                    n_spec_tokens=self.n_spec_tokens,
                )

            drafter_out = self.drafter.speculate(
                input_token_ids,
                self.n_spec_tokens,
                drafter_past_key_values,
                glide_input=glide_input,
            )
            next_token_ids_spec = drafter_out.next_tokens
            drafter_past_key_values = drafter_out.past_key_values
            drafter_spec_length = drafter_out.speculated_length

            for next_token_id_spec in next_token_ids_spec:
                self.request_handler.append_next_tokens(next_token_id_spec.unsqueeze(0))
                cur_length = batch.seq_lengths[0].item()
                if already_allocated_kv_len < cur_length:
                    self.request_handler.allocate_batch_spec_dec(batch, n=cur_length - already_allocated_kv_len)
                    already_allocated_kv_len = cur_length

            # 4. Decoding - Main model verifies `n` tokens in parallel
            if drafter_spec_length < batch.num_tokens_to_verify:
                batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)

            next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)

            # 5. Compare and process the results
            diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
            n_matches = drafter_spec_length if diff_indexes.size(0) == 0 else diff_indexes[0][0].item()

            # revoke appended tokens for each Sequence in the current batch
            batch.revoke_batch_tokens(drafter_spec_length - n_matches)  # revoke drafted tokens

            # append the last correct token generated by the main model
            self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))

            # trim past key values of the drafter model
            drafter_past_key_values = Drafter.trim_kv_cache(
                drafter_past_key_values, drafter_spec_length - n_matches - 1
            )

            # prepare inputs for the next round of speculation
            n = 1 if n_matches < drafter_spec_length else 2
            input_token_ids = batch.get_1D_inputs_spec_dec(n)

            self.request_handler.update_batch_finished(batch, generation_config=self.generation_config)
            finished_sequences = self.request_handler.update()
            if len(finished_sequences) > 0:
                break

        # Reset the number of speculated tokens of the batch.
        # This handles the last round of speculation, in which the number of tokens speculated
        # by the drafter is less than the number of speculated tokens set on the engine.
        batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)

        return finished_sequences

    def generate(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        return_token_ids: bool = False,
        generation_config: Optional[GenerationConfig] = None,
    ) -> Union[List[str], Tuple[List[str], List[List[int]]]]:
        """
        Execute the inference step.

        Args:
            request_ids (List[int], optional): The request ID. Defaults to None.
            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
            prompts_token_ids (Union[List[int], torch.Tensor, np.ndarray], optional): token ids of input prompts. Defaults to None.
            return_token_ids (bool, optional): Whether to return output token ids. Defaults to False.
            generation_config (Optional[GenerationConfig], optional): Huggingface GenerationConfig used for inference. Defaults to None.

        Returns:
            Union[List[str], Tuple[List[str], List[List[int]]]]: Inference result returned by one generation.
        """

        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
        prompts = [prompts] if isinstance(prompts, str) else prompts
        request_ids = [request_ids] if isinstance(request_ids, int) else request_ids

        with torch.inference_mode():
            if prompts is not None or prompts_token_ids is not None:
                self.add_request(
                    request_ids=request_ids,
                    prompts=prompts,
                    prompts_token_ids=prompts_token_ids,
                    **gen_config_dict,
                )

            output_seqs_list = []
            total_tokens_list = []

            # intuition: if the user provides a generation config, replace the existing one.
            if generation_config is not None:
                self.generation_config = generation_config
                self.generation_config_dict = gen_config_dict

            if self.use_spec_dec:
                assert self.drafter is not None, "Drafter Model is not initialized."
                while self.request_handler.check_unfinished_reqs():
                    output_seqs_list += self.steps_spec_dec()
            else:
                while self.request_handler.check_unfinished_reqs():
                    output_seqs_list += self.step()

            output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

            for seq in output_seqs_list:
                total_tokens_list.append(seq.input_token_id + seq.output_token_id)

            output_str = self.tokenizer.batch_decode(total_tokens_list, skip_special_tokens=True)

            if return_token_ids:
                output_tokens_list = [seq.output_token_id for seq in output_seqs_list]
                return output_str, output_tokens_list
            else:
                return output_str

    @property
    def has_prompt_template(self) -> bool:
        """Whether a prompt template is configured in the InferenceConfig."""
        return self.inference_config.prompt_template is not None

    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
        """
        Format the input prompt according to the prompt template given in the InferenceConfig.
        """
        assert (
            self.has_prompt_template
        ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."

        if isinstance(prompts, (list, tuple)):
            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
        elif isinstance(prompts, str):
            return self.inference_config.prompt_template.format(input_text=prompts)
        else:
            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")

    def add_request(
        self,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        **kwargs,
    ) -> None:
        """
        Add requests.

        Args:
            request_ids (List[int], optional): The request ID. Defaults to None.
            prompts (Union[List[str], str], optional): Input prompts. Defaults to None.
            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
        """

        # apply the prompt template to the input prompts

        if self.has_prompt_template and prompts is not None:
            prompts = self.format_prompt(prompts)

        block_size = self.inference_config.block_size

        if request_ids is not None and not isinstance(request_ids, list):
            request_ids = [request_ids]

        if prompts is not None and not isinstance(prompts, list):
            prompts = [prompts]

        if prompts_token_ids is None:
            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                "input_ids"
            ]

        # list of torch Tensor
        if isinstance(prompts_token_ids, list):
            if isinstance(prompts_token_ids[0], torch.Tensor):
                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
        elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
            prompts_token_ids = prompts_token_ids.tolist()
        else:
            raise TypeError(
                f"The dtype of prompts_token_ids must be one of list, torch.Tensor, np.ndarray, but got {type(prompts_token_ids)}."
            )

        assert (
            len(prompts_token_ids[0]) <= self.inference_config.max_input_len
        ), f"The length of input prompts {len(prompts_token_ids[0])} must be less than max_input_len {self.inference_config.max_input_len}."

        prompts_num = len(prompts_token_ids)

        for i in range(prompts_num):
            if request_ids:
                assert isinstance(
                    request_ids[0], int
                ), f"The request_id type must be int, but got {type(request_ids[0])}"
                assert len(request_ids) == prompts_num
                request_id = request_ids[i]
            else:
                request_id = next(self.counter)
            if prompts is None:
                prompt = None
            else:
                prompt = prompts[i]

            max_length = kwargs.get("max_length", None)
            max_new_tokens = kwargs.get("max_new_tokens", None)
            if max_length is None and max_new_tokens is None:
                max_new_tokens = self.generation_config.max_new_tokens or self.inference_config.max_output_len
            elif max_length is not None:
                max_new_tokens = max_length - len(prompts_token_ids[i])

            if not self.inference_config.enable_streamingllm:
                assert (
                    self.inference_config.max_output_len >= max_new_tokens
                ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."

            sequence = Sequence(
                request_id,
                prompt,
                prompts_token_ids[i],
                block_size,
                None,
                self.tokenizer.eos_token_id,
                self.tokenizer.pad_token_id,
                max_output_len=max_new_tokens,
                ignore_eos=self.inference_config.ignore_eos,
            )
            self.request_handler.add_sequence(sequence)

    def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor, InputMetaData]:
        input_ids = batch.get_1D_inputs()
        sequence_lengths = batch.get_sequence_lengths()

        if batch.is_prompts:
            n_tokens = sequence_lengths.sum().item()
        else:
            n_tokens = batch.current_batch_size
            if batch.use_spec_dec:
                n_tokens = batch.num_tokens_to_verify + 1
                assert n_tokens == input_ids.size(0)
                n_tokens = n_tokens * batch.current_batch_size
        output_tensor = torch.zeros(
            (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
        )

        batch_token_ids = None
        if (
            self.generation_config.repetition_penalty != 1.0
            or self.generation_config.no_repeat_ngram_size > 0
            or self.generation_config.forced_eos_token_id is not None
        ):
            batch_token_ids = batch.batch_token_ids

        # only when we have a captured graph for the specific decoding batch size can we use CUDA graph for inference
        use_cuda_graph = False
        if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
            use_cuda_graph = True

        input_meta_data = InputMetaData(
            block_tables=batch.get_block_table_tensor(),
            sequence_lengths=sequence_lengths,
            fd_inter_tensor=batch.fd_inter_tensor,
            batch_size=batch.current_batch_size,
            is_prompts=batch.is_prompts,
            use_cuda_kernel=self.inference_config.use_cuda_kernel,
            use_cuda_graph=use_cuda_graph,
            high_precision=self.high_precision,
            kv_seq_len=sequence_lengths.max().item(),
            head_dim=batch.head_dim,
            dtype=batch.dtype,
            use_spec_dec=batch.use_spec_dec,
            num_tokens_to_verify=batch.num_tokens_to_verify,
            batch_token_ids=batch_token_ids,
        )

        return input_ids, output_tensor, input_meta_data

    def step(self) -> List[str]:
        """
        In each step, do the following:
            1. Run RequestHandler.schedule() and get the batch used for inference.
            2. Get the input, input info, and output placeholder from the batch bucket.
            3. Run the model to generate the next token.
            4. Update the waiting list and running list in RequestHandler and get finished sequences.
            5. Decode and return finished sequences.

        Returns:
            List[str]: Decoded finished sequences generated by one step.
        """

        batch = self.request_handler.schedule()

        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
        else:
            model_executable = self.model

        # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]

        if self.inference_config.enable_streamingllm:
            updated_block_ids = batch.streamingllm_update_batch(
                self.inference_config.start_token_size, self.inference_config.generated_token_size
            )
            self.request_handler.streamingllm_free_block_tables(updated_block_ids)

        next_tokens = search_tokens(
            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
        )
        self.request_handler.append_next_tokens(next_tokens)
        finished_sequences = self.request_handler.update()

        return finished_sequences
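Similarly, a minimal usage sketch of the LLMEngine above; the module path, checkpoint id, and InferenceConfig constructor arguments are illustrative assumptions (the field names are taken from their uses in the code above):

from transformers import AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.llm_engine import LLMEngine  # module path assumed from the relative imports in the diff

# NOTE: the distributed backend must already be initialized (e.g. via colossalai's launch utilities),
# since the engine builds a ProcessGroupMesh internally; that setup is omitted here.
model_path = "meta-llama/Llama-2-7b-hf"  # assumed example checkpoint; only Llama/Baichuan causal LMs appear in _supported_models
tokenizer = AutoTokenizer.from_pretrained(model_path)

inference_config = InferenceConfig(max_input_len=256, max_output_len=128)  # assumed constructor kwargs
engine = LLMEngine(model_or_path=model_path, tokenizer=tokenizer, inference_config=inference_config, verbose=True)

outputs = engine.generate(prompts=["What is tensor parallelism?"])
print(outputs[0])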

@@ -0,0 +1,54 @@
import inspect
import types

import torch
from torch import nn


class DiffusionPipe(nn.Module):
    """
    This class converts a `DiffusionPipeline` into an `nn.Module` and preserves most of the original attributes, functions, and properties.
    """

    def __init__(self, source_obj) -> None:
        super(DiffusionPipe, self).__init__()

        for k, v in source_obj.__dict__.items():
            if isinstance(v, nn.Module):
                self.add_module(k, v)
            else:
                setattr(self, k, v)

        skip_list = ["_execution_device", "to", "device"]  # this

        for name, member in inspect.getmembers(source_obj.__class__):
            if name in skip_list:
                continue
            if not name.startswith("__") and not name.endswith("__"):
                if isinstance(member, property):
                    setattr(self.__class__, name, member)
                elif inspect.isfunction(member) or inspect.ismethod(member):
                    bound_method = types.MethodType(member, self)
                    setattr(self, name, bound_method)
                elif not callable(member) and not isinstance(member, property):
                    setattr(self, name, member)
            elif name == "__call__":
                bound_method = types.MethodType(member, self)
                setattr(self, "_forward", bound_method)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
        Accelerate's module hooks.
        """
        # return self.device
        return torch.device("cuda")

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, *args, **kwargs):
        return self._forward(*args, **kwargs)
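To show how the wrapper above is intended to be used (mirroring what DiffusionEngine.init_model does earlier in this PR), a small sketch; the checkpoint id is an illustrative assumption, and the bare wrapper is shown without the ShardFormer policy that the engine would normally apply:

import torch
from diffusers import DiffusionPipeline

from colossalai.inference.modeling.models.diffusion import DiffusionPipe

# DiffusionPipe registers the pipeline's sub-models (transformer/unet, vae, text encoders) as child modules
# and rebinds the pipeline's __call__ as _forward, so the whole pipeline behaves like a single nn.Module.
pipeline = DiffusionPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)  # assumed checkpoint
pipe_module = DiffusionPipe(pipeline)

assert isinstance(pipe_module, torch.nn.Module)  # it can now be sharded like any other module
images = pipe_module(prompt="a red panda reading a book")  # forward() dispatches to the original pipeline __call__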

@@ -0,0 +1,220 @@
# Code adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

from typing import Callable, List, Optional, Union

import PIL.Image
import torch
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
    ASPECT_RATIO_256_BIN,
    ASPECT_RATIO_512_BIN,
    ASPECT_RATIO_1024_BIN,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

from colossalai.logging import get_dist_logger

from .diffusion import DiffusionPipe

logger = get_dist_logger(__name__)


@torch.no_grad()
def pixart_alpha_forward(
    self: DiffusionPipe,
    prompt: Union[str, List[str]] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 20,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 4.5,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    use_resolution_binning: bool = True,
    max_sequence_length: int = 120,
    **kwargs,
) -> PIL.Image:
    # 1. Check inputs. Raise error if not correct
    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor
    if use_resolution_binning:
        if self.transformer.config.sample_size == 128:
            aspect_ratio_bin = ASPECT_RATIO_1024_BIN
        elif self.transformer.config.sample_size == 64:
            aspect_ratio_bin = ASPECT_RATIO_512_BIN
        elif self.transformer.config.sample_size == 32:
            aspect_ratio_bin = ASPECT_RATIO_256_BIN
        else:
            raise ValueError("Invalid sample size")
        orig_height, orig_width = height, width
        height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_steps,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    )

    # 2. Default height and width to transformer
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        clean_caption=clean_caption,
        max_sequence_length=max_sequence_length,
    )
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Prepare micro-conditions.
    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
    if self.transformer.config.sample_size == 128:
        resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
        resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

        if do_classifier_free_guidance:
            resolution = torch.cat([resolution, resolution], dim=0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

    # 7. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            current_timestep = t
            if not torch.is_tensor(current_timestep):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(current_timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
            elif len(current_timestep.shape) == 0:
                current_timestep = current_timestep[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            current_timestep = current_timestep.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                timestep=current_timestep,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # learned sigma
            if self.transformer.config.out_channels // 2 == latent_channels:
                noise_pred = noise_pred.chunk(2, dim=1)[0]
            else:
                noise_pred = noise_pred

            # compute previous image: x_t -> x_t-1
            if num_inference_steps == 1:
                # For DMD one step sampling: https://arxiv.org/abs/2311.18828
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    output_type = "pil"  # TODO(@lry89757) temporarily image, please support more return output
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        if use_resolution_binning:
            image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
    else:
        image = latents

    if not output_type == "latent":
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    # self.maybe_free_model_hooks()

    return image

@@ -0,0 +1,178 @@
# This code is adapted from huggingface diffusers: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps

from .diffusion import DiffusionPipe


# TODO(@lry89757) temporarily image, please support more return output
@torch.no_grad()
def sd3_forward(
    self: DiffusionPipe,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
    )

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])
||||
|
||||
noise_pred = self.transformer( |
||||
hidden_states=latent_model_input, |
||||
timestep=timestep, |
||||
encoder_hidden_states=prompt_embeds, |
||||
pooled_projections=pooled_prompt_embeds, |
||||
joint_attention_kwargs=self.joint_attention_kwargs, |
||||
return_dict=False, |
||||
)[0] |
||||
|
||||
# perform guidance |
||||
if self.do_classifier_free_guidance: |
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) |
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1 |
||||
latents_dtype = latents.dtype |
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] |
||||
|
||||
if latents.dtype != latents_dtype: |
||||
if torch.backends.mps.is_available(): |
||||
# some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272 |
||||
latents = latents.to(latents_dtype) |
||||
|
||||
if callback_on_step_end is not None: |
||||
callback_kwargs = {} |
||||
for k in callback_on_step_end_tensor_inputs: |
||||
callback_kwargs[k] = locals()[k] |
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) |
||||
|
||||
latents = callback_outputs.pop("latents", latents) |
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) |
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) |
||||
negative_pooled_prompt_embeds = callback_outputs.pop( |
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds |
||||
) |
||||
|
||||
# call the callback, if provided |
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
||||
progress_bar.update() |
||||
|
||||
if output_type == "latent": |
||||
image = latents |
||||
|
||||
else: |
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor |
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0] |
||||
image = self.image_processor.postprocess(image, output_type=output_type) |
||||
|
||||
return image |
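
The `callback_on_step_end` hook above is invoked once per denoising step with the pipe, the step index, the timestep, and a dict of the tensors named in `callback_on_step_end_tensor_inputs`; its return dict can overwrite those tensors. A minimal sketch of a compatible callback (the function name and the commented usage are illustrative, not part of the patch):

def log_latents(pipe, step, timestep, callback_kwargs):
    # inspect (or edit) the tensors requested via callback_on_step_end_tensor_inputs
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents {tuple(latents.shape)}, dtype={latents.dtype}")
    return {"latents": latents}

# hypothetical usage through the patched forward:
# image = sd3_forward(pipe, prompt="a cat holding a sign",
#                     callback_on_step_end=log_latents,
#                     callback_on_step_end_tensor_inputs=["latents"])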
@ -1,16 +1,22 @@
|
||||
from .glide_llama import GlideLlamaModelPolicy |
||||
from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy |
||||
from .nopadding_llama import NoPaddingLlamaModelInferPolicy |
||||
from .pixart_alpha import PixArtAlphaInferPolicy |
||||
from .stablediffusion3 import StableDiffusion3InferPolicy |
||||
|
||||
model_policy_map = { |
||||
"nopadding_llama": NoPaddingLlamaModelInferPolicy, |
||||
"nopadding_baichuan": NoPaddingBaichuanModelInferPolicy, |
||||
"glide_llama": GlideLlamaModelPolicy, |
||||
"StableDiffusion3Pipeline": StableDiffusion3InferPolicy, |
||||
"PixArtAlphaPipeline": PixArtAlphaInferPolicy, |
||||
} |
||||
|
||||
__all__ = [ |
||||
"NoPaddingLlamaModelInferPolicy", |
||||
"NoPaddingBaichuanModelInferPolicy", |
||||
"GlideLlamaModelPolicy", |
||||
"StableDiffusion3InferPolicy", |
||||
"PixArtAlphaInferPolicy", |
||||
"model_polic_map", |
||||
] |
||||
|
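
The two new registry keys are the diffusers pipeline class names, which suggests the engine resolves a default policy from the model's class name when no policy is passed explicitly. A minimal lookup sketch (not part of the patch; the import path is taken from the example script in this PR):

from colossalai.inference.modeling.policy import model_policy_map

policy_cls = model_policy_map["StableDiffusion3Pipeline"]   # -> StableDiffusion3InferPolicy
policy = policy_cls()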
@ -0,0 +1,34 @@
|
||||
from torch import nn |
||||
|
||||
from colossalai.inference.config import RPC_PARAM |
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe |
||||
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward |
||||
from colossalai.shardformer.policies.base_policy import Policy |
||||
|
||||
|
||||
class PixArtAlphaInferPolicy(Policy, RPC_PARAM): |
||||
def __init__(self) -> None: |
||||
super().__init__() |
||||
|
||||
def module_policy(self): |
||||
policy = {} |
||||
self.append_or_create_method_replacement( |
||||
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe |
||||
) |
||||
return policy |
||||
|
||||
def preprocess(self) -> nn.Module: |
||||
return self.model |
||||
|
||||
def postprocess(self): |
||||
return self.model |
||||
|
||||
def config_sanity_check(self): |
||||
pass |
||||
|
||||
def to_rpc_param(self) -> str: |
||||
return __class__.__name__ |
||||
|
||||
@staticmethod |
||||
def from_rpc_param() -> "PixArtAlphaInferPolicy": |
||||
return PixArtAlphaInferPolicy() |
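
`to_rpc_param`/`from_rpc_param` serialize the policy as its bare class name, presumably so it can be recreated on the remote side of the RPC-based engine. A round-trip sketch (not part of the patch):

from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy

name = PixArtAlphaInferPolicy().to_rpc_param()     # "PixArtAlphaInferPolicy"
policy = PixArtAlphaInferPolicy.from_rpc_param()   # fresh instance on the receiving end
assert name == type(policy).__name__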
@ -0,0 +1,34 @@
|
||||
from torch import nn |
||||
|
||||
from colossalai.inference.config import RPC_PARAM |
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe |
||||
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward |
||||
from colossalai.shardformer.policies.base_policy import Policy |
||||
|
||||
|
||||
class StableDiffusion3InferPolicy(Policy, RPC_PARAM): |
||||
def __init__(self) -> None: |
||||
super().__init__() |
||||
|
||||
def module_policy(self): |
||||
policy = {} |
||||
self.append_or_create_method_replacement( |
||||
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe |
||||
) |
||||
return policy |
||||
|
||||
def preprocess(self) -> nn.Module: |
||||
return self.model |
||||
|
||||
def postprocess(self): |
||||
return self.model |
||||
|
||||
def config_sanity_check(self): |
||||
pass |
||||
|
||||
def to_rpc_param(self) -> str: |
||||
return __class__.__name__ |
||||
|
||||
@staticmethod |
||||
def from_rpc_param() -> "StableDiffusion3InferPolicy": |
||||
return StableDiffusion3InferPolicy() |
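
The policy itself only registers a forward replacement for `DiffusionPipe`; ShardFormer performs the actual rebinding. As a rough, self-contained illustration of the effect (this is not the ShardFormer code path, and the dummy classes below are stand-ins, not real ColossalAI types):

import types

class DummyPipe:                                   # stand-in for DiffusionPipe
    def forward(self):
        return "original forward"

def patched_forward(self):                         # stand-in for sd3_forward
    return "patched forward"

pipe = DummyPipe()
pipe.forward = types.MethodType(patched_forward, pipe)   # what the replacement amounts to
assert pipe.forward() == "patched forward"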
@ -0,0 +1,75 @@
|
||||
import argparse |
||||
|
||||
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline |
||||
from torch import bfloat16, float16, float32 |
||||
|
||||
import colossalai |
||||
from colossalai.cluster import DistCoordinator |
||||
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig |
||||
from colossalai.inference.core.engine import InferenceEngine |
||||
from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy |
||||
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy |
||||
|
||||
# Select the pipeline/policy pair: index 0 = Stable Diffusion 3, index 1 = PixArt-Alpha |
||||
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0] |
||||
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0] |
||||
|
||||
TORCH_DTYPE_MAP = { |
||||
"fp16": float16, |
||||
"fp32": float32, |
||||
"bf16": bfloat16, |
||||
} |
||||
|
||||
|
||||
def infer(args): |
||||
# ============================== |
||||
# Launch colossalai, setup distributed environment |
||||
# ============================== |
||||
colossalai.launch_from_torch() |
||||
coordinator = DistCoordinator() |
||||
|
||||
# ============================== |
||||
# Load model and tokenizer |
||||
# ============================== |
||||
model_path_or_name = args.model |
||||
model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None)) |
||||
|
||||
# ============================== |
||||
# Initialize InferenceEngine |
||||
# ============================== |
||||
coordinator.print_on_master(f"Initializing Inference Engine...") |
||||
inference_config = InferenceConfig( |
||||
dtype=args.dtype, |
||||
max_batch_size=args.max_batch_size, |
||||
tp_size=args.tp_size, |
||||
use_cuda_kernel=args.use_cuda_kernel, |
||||
) |
||||
engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True) |
||||
|
||||
# ============================== |
||||
# Generation |
||||
# ============================== |
||||
coordinator.print_on_master(f"Generating...") |
||||
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0] |
||||
out.save("cat.jpg") |
||||
coordinator.print_on_master(out) |
||||
|
||||
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH |
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1 |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
# ============================== |
||||
# Parse Arguments |
||||
# ============================== |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument("-m", "--model", type=str, help="Path to the model or model name") |
||||
parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size") |
||||
parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt") |
||||
parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size") |
||||
parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"]) |
||||
parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default") |
||||
args = parser.parse_args() |
||||
|
||||
infer(args) |
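
The same script drives PixArt-Alpha by flipping the two selectors at the top of the file to index 1. A sketch (the PixArt model id below is an assumption; substitute your own checkpoint path):

MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][1]
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][1]

# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py \
#     -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1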