mirror of https://github.com/hpcaitech/ColossalAI
[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement

branch: pull/5951/head
parent: 7b38964e3a
commit: bcf0181ecd
|
@ -18,7 +18,7 @@
|
|||
|
||||
|
||||
## 📌 Introduction
|
||||
ColossalAI-Inference is a module that accelerates the inference execution of Transformer models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate LLM inference. We also provide simple and unified APIs for user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
|
||||
ColossalAI-Inference is a module that accelerates the inference execution of Transformer models, especially LLMs and DiT diffusion models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate LLM inference. We also provide simple and unified APIs for user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
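
As a quick illustration of the unified API on the new diffusion path, here is a minimal sketch; the checkpoint id is only an example and the launch call assumes a `colossalai run` start, while the full script added by this PR lives in `examples/inference/stable_diffusion/sd3_generation.py`:

```python
import colossalai
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch()  # assuming launch via `colossalai run --nproc_per_node 1 ...`

# build the engine from a diffusers-style checkpoint and a plain InferenceConfig
engine = InferenceEngine(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    inference_config=InferenceConfig(dtype="fp16"),
    verbose=True,
)

# generate a single image; the diffusion path returns image objects with .save()
image = engine.generate(
    prompts=["a photo of a cat"],
    generation_config=DiffusionGenerationConfig(num_inference_steps=50),
)[0]
image.save("cat.jpg")
```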
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/colossal-inference-v1-1.png" width=1000/>
|
||||
|
@ -310,4 +310,14 @@ If you wish to cite relevant research papers, you can find the reference below.
|
|||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
# Distrifusion
|
||||
@InProceedings{Li_2024_CVPR,
|
||||
author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
|
||||
title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
|
||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
month={June},
|
||||
year={2024},
|
||||
pages={7183-7193}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
|
|||
enable_streamingllm(bool): Whether to use StreamingLLM; the underlying algorithm is described in the paper at https://arxiv.org/pdf/2309.17453.
|
||||
start_token_size(int): The size of the start tokens, when using StreamingLLM.
|
||||
generated_token_size(int): The size of the generated tokens, when using StreamingLLM.
|
||||
patched_parallelism_size(int): The patched parallelism size, when using Distrifusion.
|
||||
"""
|
||||
|
||||
# NOTE: arrange configs according to their importance and frequency of usage
|
||||
|
@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
|
|||
start_token_size: int = 4
|
||||
generated_token_size: int = 512
|
||||
|
||||
# Acceleration for Diffusion Models (PipeFusion or Distrifusion)
|
||||
patched_parallelism_size: int = 1 # for distrifusion
|
||||
# pipeFusion_m_size: int = 1 # for pipefusion
|
||||
# pipeFusion_n_size: int = 1 # for pipefusion
|
||||
|
||||
def __post_init__(self):
|
||||
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
|
||||
self._verify_config()
|
||||
|
@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
|
|||
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
|
||||
self.start_token_size = self.block_size
|
||||
|
||||
# check Distrifusion
|
||||
# TODO(@lry89757) need more detailed check
|
||||
if self.patched_parallelism_size > 1:
|
||||
# self.use_patched_parallelism = True
|
||||
self.tp_size = (
|
||||
self.patched_parallelism_size
|
||||
) # not a real tp: some sanity checks require tp_size, so we mirror patched_parallelism_size here
|
||||
|
||||
# check prompt template
|
||||
if self.prompt_template is None:
|
||||
return
|
||||
|
@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
|
|||
use_cuda_kernel=self.use_cuda_kernel,
|
||||
use_spec_dec=self.use_spec_dec,
|
||||
use_flash_attn=use_flash_attn,
|
||||
patched_parallelism_size=self.patched_parallelism_size,
|
||||
)
|
||||
return model_inference_config
|
||||
|
||||
|
@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
|
|||
use_cuda_kernel: bool = False
|
||||
use_spec_dec: bool = False
|
||||
use_flash_attn: bool = False
|
||||
patched_parallelism_size: int = 1 # for diffusion model, Distrifusion Technique
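# populated from InferenceConfig.patched_parallelism_size and consumed by the Distrifusion layers and policies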
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch import distributed as dist
|
|||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.struct import DiffusionSequence
|
||||
from colossalai.inference.utils import get_model_size, get_model_type
|
||||
|
|
|
@ -0,0 +1,626 @@
|
|||
# Code referenced and adapted from:
|
||||
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
|
||||
# https://github.com/PipeFusion/PipeFusion
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models import attention_processor
|
||||
from diffusers.models.attention import Attention
|
||||
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
|
||||
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
|
||||
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from torch import nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.inference.config import ModelShardInferenceConfig
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer.layer.parallel_module import ParallelModule
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
HAS_FLASH_ATTN = False
|
||||
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
|
||||
def PixArtAlphaTransformer2DModel_forward(
|
||||
self: PixArtTransformer2DModel,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
assert hasattr(
|
||||
self, "patched_parallel_size"
|
||||
), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
|
||||
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size = hidden_states.shape[0]
|
||||
height, width = (
|
||||
hidden_states.shape[-2] // self.config.patch_size,
|
||||
hidden_states.shape[-1] // self.config.patch_size,
|
||||
)
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
if self.caption_projection is not None:
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
|
||||
2, dim=1
|
||||
)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(
|
||||
-1,
|
||||
height // self.patched_parallel_size,
|
||||
width,
|
||||
self.config.patch_size,
|
||||
self.config.patch_size,
|
||||
self.out_channels,
|
||||
)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(
|
||||
-1,
|
||||
self.out_channels,
|
||||
height // self.patched_parallel_size * self.config.patch_size,
|
||||
width * self.config.patch_size,
|
||||
)
|
||||
)
|
||||
|
||||
# enable Distrifusion Optimization
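# each rank has only computed its own slice of rows (height // patched_parallel_size), so we
# all-gather the per-rank outputs and concatenate them along the height dimension (dim=2)
# to reconstruct the full latent before returning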
|
||||
if hasattr(self, "patched_parallel_size"):
|
||||
from torch import distributed as dist
|
||||
|
||||
if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
|
||||
self.output_buffer = torch.empty_like(output)
|
||||
if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
|
||||
self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
|
||||
output = output.contiguous()
|
||||
dist.all_gather(self.buffer_list, output, async_op=False)
|
||||
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
|
||||
output = self.output_buffer
|
||||
|
||||
return (output,)
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
|
||||
def SD3Transformer2DModel_forward(
|
||||
self: SD3Transformer2DModel,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
pooled_projections: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor]:
|
||||
|
||||
assert hasattr(
|
||||
self, "patched_parallel_size"
|
||||
), "please check your policy, `Transformer2DModel` Must have attribute `patched_parallel_size`"
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
patch_size = self.config.patch_size
|
||||
height = height // patch_size // self.patched_parallel_size
|
||||
width = width // patch_size
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
# enable Distrifusion Optimization
|
||||
if hasattr(self, "patched_parallel_size"):
|
||||
from torch import distributed as dist
|
||||
|
||||
if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
|
||||
self.output_buffer = torch.empty_like(output)
|
||||
if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
|
||||
self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
|
||||
output = output.contiguous()
|
||||
dist.all_gather(self.buffer_list, output, async_op=False)
|
||||
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
|
||||
output = self.output_buffer
|
||||
|
||||
return (output,)
|
||||
|
||||
|
||||
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
|
||||
class DistrifusionPatchEmbed(ParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
module: PatchEmbed,
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]],
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.rank = dist.get_rank(group=process_group)
|
||||
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
|
||||
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
|
||||
distrifusion_embed = DistrifusionPatchEmbed(
|
||||
module, process_group, model_shard_infer_config=model_shard_infer_config
|
||||
)
|
||||
return distrifusion_embed
|
||||
|
||||
def forward(self, latent):
|
||||
module = self.module
|
||||
if module.pos_embed_max_size is not None:
|
||||
height, width = latent.shape[-2:]
|
||||
else:
|
||||
height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size
|
||||
|
||||
latent = module.proj(latent)
|
||||
if module.flatten:
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
if module.layer_norm:
|
||||
latent = module.norm(latent)
|
||||
if module.pos_embed is None:
|
||||
return latent.to(latent.dtype)
|
||||
# Interpolate or crop positional embeddings as needed
|
||||
if module.pos_embed_max_size:
|
||||
pos_embed = module.cropped_pos_embed(height, width)
|
||||
else:
|
||||
if module.height != height or module.width != width:
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
embed_dim=module.pos_embed.shape[-1],
|
||||
grid_size=(height, width),
|
||||
base_size=module.base_size,
|
||||
interpolation_scale=module.interpolation_scale,
|
||||
)
|
||||
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
|
||||
else:
|
||||
pos_embed = module.pos_embed
|
||||
|
||||
b, c, h = pos_embed.shape
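# keep only this rank's chunk of the positional embedding so it matches the patch rows
# (height slice) this rank actually processes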
|
||||
pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]
|
||||
|
||||
return (latent + pos_embed).to(latent.dtype)
|
||||
|
||||
|
||||
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
|
||||
class DistrifusionConv2D(ParallelModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Conv2d,
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]],
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.rank = dist.get_rank(group=process_group)
|
||||
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
|
||||
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
|
||||
distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
|
||||
return distrifusion_conv
|
||||
|
||||
def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
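# each rank convolves only its own horizontal (height) slice of the input: take this rank's
# rows plus `padding` halo rows from the neighbouring slices, pad the remaining borders
# explicitly, then run a "valid" convolution so the result matches the corresponding slice
# of a full conv2d output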
|
||||
|
||||
b, c, h, w = x.shape
|
||||
|
||||
stride = self.module.stride[0]
|
||||
padding = self.module.padding[0]
|
||||
|
||||
output_h = x.shape[2] // stride // self.patched_parallelism_size
|
||||
idx = dist.get_rank()
|
||||
h_begin = output_h * idx * stride - padding
|
||||
h_end = output_h * (idx + 1) * stride + padding
|
||||
final_padding = [padding, padding, 0, 0]
|
||||
if h_begin < 0:
|
||||
h_begin = 0
|
||||
final_padding[2] = padding
|
||||
if h_end > h:
|
||||
h_end = h
|
||||
final_padding[3] = padding
|
||||
sliced_input = x[:, :, h_begin:h_end, :]
|
||||
padded_input = F.pad(sliced_input, final_padding, mode="constant")
|
||||
return F.conv2d(
|
||||
padded_input,
|
||||
self.module.weight,
|
||||
self.module.bias,
|
||||
stride=stride,
|
||||
padding="valid",
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
output = self.sliced_forward(input)
|
||||
return output
|
||||
|
||||
|
||||
# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
|
||||
class DistrifusionFusedAttention(ParallelModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: attention_processor.Attention,
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]],
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.counter = 0
|
||||
self.module = module
|
||||
self.buffer_list = None
|
||||
self.kv_buffer_idx = dist.get_rank(group=process_group)
|
||||
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
|
||||
self.handle = None
|
||||
self.process_group = process_group
|
||||
self.warm_step = 5  # number of warmup steps that all-gather KV synchronously
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
|
||||
return DistrifusionFusedAttention(
|
||||
module=module,
|
||||
process_group=process_group,
|
||||
model_shard_infer_config=model_shard_infer_config,
|
||||
)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
|
||||
|
||||
if self.patched_parallelism_size == 1:
|
||||
full_kv = kv
|
||||
else:
|
||||
if self.buffer_list is None: # buffer not created
|
||||
full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
|
||||
elif self.counter <= self.warm_step:
|
||||
# logger.info(f"warmup: {self.counter}")
|
||||
dist.all_gather(
|
||||
self.buffer_list,
|
||||
kv,
|
||||
group=self.process_group,
|
||||
async_op=False,
|
||||
)
|
||||
full_kv = torch.cat(self.buffer_list, dim=1)
|
||||
else:
|
||||
# logger.info(f"use old kv to infer: {self.counter}")
|
||||
self.buffer_list[self.kv_buffer_idx].copy_(kv)
|
||||
full_kv = torch.cat(self.buffer_list, dim=1)
|
||||
assert self.handle is None, "we should maintain the kv of last step"
|
||||
self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
|
||||
|
||||
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# attention
|
||||
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
||||
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
||||
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, is_causal=False
|
||||
) # NOTE(@lry89757) for torch >= 2.2, flash attn has already been integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : residual.shape[1]],
|
||||
hidden_states[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.handle is not None:
|
||||
self.handle.wait()
|
||||
self.handle = None
|
||||
|
||||
b, l, c = hidden_states.shape
|
||||
kv_shape = (b, l, self.module.to_k.out_features * 2)
|
||||
if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
|
||||
|
||||
self.buffer_list = [
|
||||
torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
|
||||
for _ in range(self.patched_parallelism_size)
|
||||
]
|
||||
|
||||
self.counter = 0
|
||||
|
||||
attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {"ip_adapter_masks"}
|
||||
unused_kwargs = [
|
||||
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
|
||||
]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
|
||||
|
||||
output = self._forward(
|
||||
self.module,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
self.counter += 1
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
|
||||
class DistriSelfAttention(ParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
module: Attention,
|
||||
process_group: Union[ProcessGroup, List[ProcessGroup]],
|
||||
model_shard_infer_config: ModelShardInferenceConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.counter = 0
|
||||
self.module = module
|
||||
self.buffer_list = None
|
||||
self.kv_buffer_idx = dist.get_rank(group=process_group)
|
||||
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
|
||||
self.handle = None
|
||||
self.process_group = process_group
|
||||
self.warm_step = 3  # number of warmup steps that all-gather KV synchronously
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
|
||||
return DistriSelfAttention(
|
||||
module=module,
|
||||
process_group=process_group,
|
||||
model_shard_infer_config=model_shard_infer_config,
|
||||
)
|
||||
|
||||
def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
|
||||
attn = self.module
|
||||
assert isinstance(attn, Attention)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
encoder_hidden_states = hidden_states
|
||||
k = self.module.to_k(encoder_hidden_states)
|
||||
v = self.module.to_v(encoder_hidden_states)
|
||||
kv = torch.cat([k, v], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
|
||||
|
||||
if self.patched_parallelism_size == 1:
|
||||
full_kv = kv
|
||||
else:
|
||||
if self.buffer_list is None: # buffer not created
|
||||
full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
|
||||
elif self.counter <= self.warm_step:
|
||||
# logger.info(f"warmup: {self.counter}")
|
||||
dist.all_gather(
|
||||
self.buffer_list,
|
||||
kv,
|
||||
group=self.process_group,
|
||||
async_op=False,
|
||||
)
|
||||
full_kv = torch.cat(self.buffer_list, dim=1)
|
||||
else:
|
||||
# logger.info(f"use old kv to infer: {self.counter}")
|
||||
self.buffer_list[self.kv_buffer_idx].copy_(kv)
|
||||
full_kv = torch.cat(self.buffer_list, dim=1)
|
||||
assert self.handle is None, "we should maintain the kv of last step"
|
||||
self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
# flash attn
|
||||
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
|
||||
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
|
||||
else:
|
||||
# naive attn
|
||||
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
# wait for the async KV all-gather launched in the previous step before reusing the buffer
|
||||
if self.handle is not None:
|
||||
self.handle.wait()
|
||||
self.handle = None
|
||||
|
||||
b, l, c = hidden_states.shape
|
||||
kv_shape = (b, l, self.module.to_k.out_features * 2)
|
||||
if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
|
||||
|
||||
self.buffer_list = [
|
||||
torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
|
||||
for _ in range(self.patched_parallelism_size)
|
||||
]
|
||||
|
||||
self.counter = 0
|
||||
|
||||
output = self._forward(hidden_states, scale=scale)
|
||||
|
||||
self.counter += 1
|
||||
return output
|
|
@ -14,7 +14,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retri
|
|||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .diffusion import DiffusionPipe
|
||||
from ..layers.diffusion import DiffusionPipe
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||
import torch
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
|
||||
|
||||
from .diffusion import DiffusionPipe
|
||||
from ..layers.diffusion import DiffusionPipe
|
||||
|
||||
|
||||
# TODO(@lry89757) currently only image outputs are returned; support more output types
|
||||
|
|
|
@ -1,9 +1,17 @@
|
|||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
|
||||
from torch import nn
|
||||
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.layers.distrifusion import (
|
||||
DistrifusionConv2D,
|
||||
DistrifusionPatchEmbed,
|
||||
DistriSelfAttention,
|
||||
PixArtAlphaTransformer2DModel_forward,
|
||||
)
|
||||
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
|
||||
|
@ -12,9 +20,46 @@ class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
|
|||
|
||||
def module_policy(self):
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
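# Distrifusion (Patched Parallelism): replace the patch embedding and its conv with sliced
# variants, swap self-attention for the stale-KV DistriSelfAttention, and patch the
# transformer forward so the per-rank height patches are gathered back into the full output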
|
||||
|
||||
policy[PixArtTransformer2DModel] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="pos_embed.proj",
|
||||
target_module=DistrifusionConv2D,
|
||||
kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="pos_embed",
|
||||
target_module=DistrifusionPatchEmbed,
|
||||
kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
|
||||
),
|
||||
],
|
||||
attribute_replacement={
|
||||
"patched_parallel_size": self.shard_config.extra_kwargs[
|
||||
"model_shard_infer_config"
|
||||
].patched_parallelism_size
|
||||
},
|
||||
method_replacement={"forward": PixArtAlphaTransformer2DModel_forward},
|
||||
)
|
||||
|
||||
policy[BasicTransformerBlock] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn1",
|
||||
target_module=DistriSelfAttention,
|
||||
kwargs={
|
||||
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def preprocess(self) -> nn.Module:
|
||||
|
|
|
@ -1,9 +1,17 @@
|
|||
from diffusers.models.attention import JointTransformerBlock
|
||||
from diffusers.models.transformers import SD3Transformer2DModel
|
||||
from torch import nn
|
||||
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
|
||||
from colossalai.inference.modeling.layers.distrifusion import (
|
||||
DistrifusionConv2D,
|
||||
DistrifusionFusedAttention,
|
||||
DistrifusionPatchEmbed,
|
||||
SD3Transformer2DModel_forward,
|
||||
)
|
||||
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
|
||||
|
@ -12,6 +20,42 @@ class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
|
|||
|
||||
def module_policy(self):
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
|
||||
|
||||
policy[SD3Transformer2DModel] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="pos_embed.proj",
|
||||
target_module=DistrifusionConv2D,
|
||||
kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="pos_embed",
|
||||
target_module=DistrifusionPatchEmbed,
|
||||
kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
|
||||
),
|
||||
],
|
||||
attribute_replacement={
|
||||
"patched_parallel_size": self.shard_config.extra_kwargs[
|
||||
"model_shard_infer_config"
|
||||
].patched_parallelism_size
|
||||
},
|
||||
method_replacement={"forward": SD3Transformer2DModel_forward},
|
||||
)
|
||||
|
||||
policy[JointTransformerBlock] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn",
|
||||
target_module=DistrifusionFusedAttention,
|
||||
kwargs={
|
||||
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
|
||||
)
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
## File Structure
|
||||
```
|
||||
|- sd3_generation.py: an example of how to use the ColossalAI inference engine to generate images by loading a diffusion model.
|
||||
|- compute_metric.py: compare the quality of images generated with and without acceleration methods such as Distrifusion
|
||||
|- benchmark_sd3.py: benchmark the performance of our InferenceEngine
|
||||
|- run_benchmark.sh: script to run the benchmark commands
|
||||
```
|
||||
Note: `compute_metric.py` requires extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` is located in `examples/inference/stable_diffusion/`).
|
||||
|
||||
## Run Inference
|
||||
|
||||
The provided `sd3_generation.py` shows how to configure the engine, initialize it, and run inference on a given model. We've added `DiffusionPipeline` as the model class, and the script is ready to run inference with Stable Diffusion 3.
|
||||
|
||||
For a basic single-GPU setting, you can run the example with:
|
||||
```bash
|
||||
colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
|
||||
```
|
||||
|
||||
Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs:
|
||||
```bash
|
||||
colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
|
||||
```
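
For reference, the script roughly boils down to the following simplified sketch (the real script parses `-m`/`-p` arguments and prints through `DistCoordinator`; the `colossalai.launch_from_torch()` call is an assumption about how the process group is initialized). Note how `patched_parallelism_size` is taken from the launched world size to enable Distrifusion:

```python
import colossalai
from torch import distributed as dist

from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch()  # started via `colossalai run --nproc_per_node N sd3_generation.py ...`

inference_config = InferenceConfig(
    dtype="fp16",
    patched_parallelism_size=dist.get_world_size(),  # > 1 enables Distrifusion (Patched Parallelism)
)
engine = InferenceEngine(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    inference_config=inference_config,
    verbose=True,
)

out = engine.generate(prompts=["hello world"], generation_config=DiffusionGenerationConfig())[0]
if dist.get_rank() == 0:  # save once, from rank 0
    out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
```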
|
|
@ -0,0 +1,179 @@
|
|||
import argparse
|
||||
import json
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
|
||||
GIGABYTE = 1024**3
|
||||
MEGABYTE = 1024 * 1024
|
||||
|
||||
_DTYPE_MAPPING = {
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
"fp32": torch.float32,
|
||||
}
|
||||
|
||||
|
||||
def log_generation_time(log_data, log_file):
|
||||
with open(log_file, "a") as f:
|
||||
json.dump(log_data, f, indent=2)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def warmup(engine, args):
|
||||
for _ in range(args.n_warm_up_steps):
|
||||
engine.generate(
|
||||
prompts=["hello world"],
|
||||
generation_config=DiffusionGenerationConfig(
|
||||
num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def profile_context(args):
|
||||
return (
|
||||
torch.profiler.profile(
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
with_modules=True,
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
)
|
||||
if args.profile
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
|
||||
def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
|
||||
log_data = {
|
||||
"mode": mode,
|
||||
"model": model_name,
|
||||
"batch_size": args.batch_size,
|
||||
"patched_parallel_size": args.patched_parallel_size,
|
||||
"num_inference_steps": args.num_inference_steps,
|
||||
"height": h,
|
||||
"width": w,
|
||||
"dtype": args.dtype,
|
||||
"profile": args.profile,
|
||||
"n_warm_up_steps": args.n_warm_up_steps,
|
||||
"n_repeat_times": args.n_repeat_times,
|
||||
"avg_generation_time": avg_time,
|
||||
"log_message": log_msg,
|
||||
}
|
||||
|
||||
if args.log:
|
||||
log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
|
||||
log_generation_time(log_data=log_data, log_file=log_file)
|
||||
|
||||
if args.profile:
|
||||
file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
|
||||
prof.export_chrome_trace(file)
|
||||
|
||||
|
||||
def benchmark_colossalai(rank, world_size, port, args):
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
from colossalai.cluster.dist_coordinator import DistCoordinator
|
||||
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
inference_config = InferenceConfig(
|
||||
dtype=args.dtype,
|
||||
patched_parallelism_size=args.patched_parallel_size,
|
||||
)
|
||||
engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)
|
||||
|
||||
warmup(engine, args)
|
||||
|
||||
for h, w in zip(args.height, args.width):
|
||||
with profile_context(args) as prof:
|
||||
start = time.perf_counter()
|
||||
for _ in range(args.n_repeat_times):
|
||||
engine.generate(
|
||||
prompts=["hello world"],
|
||||
generation_config=DiffusionGenerationConfig(
|
||||
num_inference_steps=args.num_inference_steps, height=h, width=w
|
||||
),
|
||||
)
|
||||
end = time.perf_counter()
|
||||
|
||||
avg_time = (end - start) / args.n_repeat_times
|
||||
log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
|
||||
coordinator.print_on_master(log_msg)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)
|
||||
|
||||
|
||||
def benchmark_diffusers(args):
|
||||
model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")
|
||||
|
||||
for _ in range(args.n_warm_up_steps):
|
||||
model(
|
||||
prompt="hello world",
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
height=args.height[0],
|
||||
width=args.width[0],
|
||||
)
|
||||
|
||||
for h, w in zip(args.height, args.width):
|
||||
with profile_context(args) as prof:
|
||||
start = time.perf_counter()
|
||||
for _ in range(args.n_repeat_times):
|
||||
model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
|
||||
end = time.perf_counter()
|
||||
|
||||
avg_time = (end - start) / args.n_repeat_times
|
||||
log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
|
||||
print(log_msg)
|
||||
|
||||
log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def benchmark(args):
|
||||
if args.mode == "colossalai":
|
||||
spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
|
||||
elif args.mode == "diffusers":
|
||||
benchmark_diffusers(args)
|
||||
|
||||
|
||||
"""
|
||||
# enable log
|
||||
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
|
||||
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log
|
||||
|
||||
# enable profiler
|
||||
python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
|
||||
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
|
||||
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
|
||||
parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
|
||||
parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
|
||||
parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
|
||||
parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
|
||||
parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
|
||||
parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps")
|
||||
parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
|
||||
parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
|
||||
parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
|
||||
parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
|
||||
parser.add_argument(
|
||||
"--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
benchmark(args)
|
|
@ -0,0 +1,80 @@
|
|||
# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from cleanfid import fid
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
|
||||
from torchvision.transforms import Resize
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def read_image(path: str):
|
||||
"""
|
||||
input: path
|
||||
output: tensor (C, H, W)
|
||||
"""
|
||||
img = np.asarray(Image.open(path))
|
||||
if len(img.shape) == 2:
|
||||
img = np.repeat(img[:, :, None], 3, axis=2)
|
||||
img = torch.from_numpy(img).permute(2, 0, 1)
|
||||
return img
|
||||
|
||||
|
||||
class MultiImageDataset(Dataset):
|
||||
def __init__(self, root0, root1, is_gt=False):
|
||||
super().__init__()
|
||||
self.root0 = root0
|
||||
self.root1 = root1
|
||||
file_names0 = os.listdir(root0)
|
||||
file_names1 = os.listdir(root1)
|
||||
|
||||
self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
|
||||
self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
|
||||
self.is_gt = is_gt
|
||||
assert len(self.image_names0) == len(self.image_names1)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_names0)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
|
||||
if self.is_gt:
|
||||
# resize to 1024 x 1024
|
||||
img0 = Resize((1024, 1024))(img0)
|
||||
img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
|
||||
|
||||
batch_list = [img0, img1]
|
||||
return batch_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--is_gt", action="store_true")
|
||||
parser.add_argument("--input_root0", type=str, required=True)
|
||||
parser.add_argument("--input_root1", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
|
||||
|
||||
dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
|
||||
dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
||||
|
||||
progress_bar = tqdm(dataloader)
|
||||
with torch.inference_mode():
|
||||
for i, batch in enumerate(progress_bar):
|
||||
batch = [img.to("cuda") / 255 for img in batch]
|
||||
batch_size = batch[0].shape[0]
|
||||
psnr.update(batch[0], batch[1])
|
||||
lpips.update(batch[0], batch[1])
|
||||
fid_score = fid.compute_fid(args.input_root0, args.input_root1)
|
||||
|
||||
print("PSNR:", psnr.compute().item())
|
||||
print("LPIPS:", lpips.compute().item())
|
||||
print("FID:", fid_score)
|
|
@ -0,0 +1,3 @@
|
|||
torchvision
|
||||
torchmetrics
|
||||
cleanfid
|
|
@ -0,0 +1,42 @@
|
|||
#!/bin/bash
|
||||
|
||||
models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
|
||||
parallelism=(1 2 4 8)
|
||||
resolutions=(1024 2048 3840)
|
||||
modes=("colossalai" "diffusers")
|
||||
|
||||
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
for p in "${parallelism[@]}"; do
|
||||
for resolution in "${resolutions[@]}"; do
|
||||
for mode in "${modes[@]}"; do
|
||||
if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
|
||||
continue
|
||||
fi
|
||||
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p
|
||||
|
||||
cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"
|
||||
|
||||
echo "Executing: $cmd"
|
||||
eval $cmd
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
|
@ -1,18 +1,17 @@
|
|||
import argparse
|
||||
|
||||
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
|
||||
from torch import bfloat16, float16, float32
|
||||
from diffusers import DiffusionPipeline
|
||||
from torch import bfloat16
|
||||
from torch import distributed as dist
|
||||
from torch import float16, float32
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
|
||||
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
|
||||
|
||||
# For Stable Diffusion 3, we'll use the following configuration
|
||||
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
|
||||
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
|
||||
MODEL_CLS = DiffusionPipeline
|
||||
|
||||
TORCH_DTYPE_MAP = {
|
||||
"fp16": float16,
|
||||
|
@ -43,20 +42,27 @@ def infer(args):
|
|||
max_batch_size=args.max_batch_size,
|
||||
tp_size=args.tp_size,
|
||||
use_cuda_kernel=args.use_cuda_kernel,
|
||||
patched_parallelism_size=dist.get_world_size(),
|
||||
)
|
||||
engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
|
||||
engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
|
||||
|
||||
# ==============================
|
||||
# Generation
|
||||
# ==============================
|
||||
coordinator.print_on_master(f"Generating...")
|
||||
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
|
||||
out.save("cat.jpg")
|
||||
if dist.get_rank() == 0:
|
||||
out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
|
||||
coordinator.print_on_master(out)
|
||||
|
||||
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
|
||||
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
|
||||
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
|
||||
|
||||
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
|
||||
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|