[Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement
Runyu Lu 2024-07-30 10:43:26 +08:00 committed by GitHub
parent 7b38964e3a
commit bcf0181ecd
15 changed files with 1089 additions and 16 deletions

View File

@ -18,7 +18,7 @@
## 📌 Introduction
ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
ColossalAI-Inference is a module which offers acceleration to the inference execution of Transformers models, especially LLMs and DiT Diffusion Models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching and other techniques to accelerate the inference of LLMs. We also provide simple and unified APIs for the sake of user-friendliness. [[blog]](https://hpc-ai.com/blog/colossal-inference)
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/colossal-inference-v1-1.png" width=1000/>
@ -310,4 +310,14 @@ If you wish to cite relevant research papers, you can find the reference below.
journal={arXiv},
year={2023}
}
# Distrifusion
@InProceedings{Li_2024_CVPR,
author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month={June},
year={2024},
pages={7183-7193}
}
```

View File

@ -186,6 +186,7 @@ class InferenceConfig(RPC_PARAM):
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, when using StreamingLLM.
patched_parallelism_size(int): The patched parallelism size, when using Distrifusion.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@ -245,6 +246,11 @@ class InferenceConfig(RPC_PARAM):
start_token_size: int = 4
generated_token_size: int = 512
# Acceleration for Diffusion Models (PipeFusion or Distrifusion)
patched_parallelism_size: int = 1 # for distrifusion
# pipeFusion_m_size: int = 1 # for pipefusion
# pipeFusion_n_size: int = 1 # for pipefusion
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
self._verify_config()
@ -288,6 +294,14 @@ class InferenceConfig(RPC_PARAM):
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
# check Distrifusion
# TODO(@lry89757) need more detailed check
if self.patched_parallelism_size > 1:
# self.use_patched_parallelism = True
self.tp_size = (
self.patched_parallelism_size
) # not real tensor parallelism: some existing engine checks require tp_size to match the parallel world size, so we reuse it for patched_parallelism_size
# check prompt template
if self.prompt_template is None:
return
@ -324,6 +338,7 @@ class InferenceConfig(RPC_PARAM):
use_cuda_kernel=self.use_cuda_kernel,
use_spec_dec=self.use_spec_dec,
use_flash_attn=use_flash_attn,
patched_parallelism_size=self.patched_parallelism_size,
)
return model_inference_config
@ -396,6 +411,7 @@ class ModelShardInferenceConfig:
use_cuda_kernel: bool = False
use_spec_dec: bool = False
use_flash_attn: bool = False
patched_parallelism_size: int = 1 # for diffusion models (Distrifusion technique)
@dataclass
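
Taken together, these changes mean Distrifusion is enabled purely through `InferenceConfig`: setting `patched_parallelism_size` above 1 activates the patched-parallel path and propagates the size into `ModelShardInferenceConfig`. A minimal usage sketch (mirroring the `sd3_generation.py` and `benchmark_sd3.py` examples added later in this PR; the launch helper and model id here are illustrative, not prescribed by this change):

```python
import colossalai
from torch import distributed as dist
from diffusers import DiffusionPipeline

from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# assumes the script is started with `colossalai run --nproc_per_node 2 ...`
colossalai.launch_from_torch()

model = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")

inference_config = InferenceConfig(
    dtype="fp16",
    patched_parallelism_size=dist.get_world_size(),  # > 1 enables Distrifusion
)
engine = InferenceEngine(model, inference_config=inference_config, verbose=True)

image = engine.generate(
    prompts=["a photo of a cat"],
    generation_config=DiffusionGenerationConfig(num_inference_steps=50),
)[0]
if dist.get_rank() == 0:
    image.save("cat.jpg")
```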

View File

@ -11,7 +11,7 @@ from torch import distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig, ModelShardInferenceConfig
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import DiffusionSequence
from colossalai.inference.utils import get_model_size, get_model_type

View File

@ -0,0 +1,626 @@
# Code referenced and adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
# https://github.com/PipeFusion/PipeFusion
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from diffusers.models import attention_processor
from diffusers.models.attention import Attention
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torch import nn
from torch.distributed import ProcessGroup
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.utils import get_current_device
try:
from flash_attn import flash_attn_func
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
logger = get_dist_logger(__name__)
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
def PixArtAlphaTransformer2DModel_forward(
self: PixArtTransformer2DModel,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
assert hasattr(
self, "patched_parallel_size"
), "please check your policy: `Transformer2DModel` must have the attribute `patched_parallel_size`"
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size = hidden_states.shape[0]
height, width = (
hidden_states.shape[-2] // self.config.patch_size,
hidden_states.shape[-1] // self.config.patch_size,
)
hidden_states = self.pos_embed(hidden_states)
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
if self.caption_projection is not None:
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
2, dim=1
)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(
-1,
height // self.patched_parallel_size,
width,
self.config.patch_size,
self.config.patch_size,
self.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(
-1,
self.out_channels,
height // self.patched_parallel_size * self.config.patch_size,
width * self.config.patch_size,
)
)
# enable Distrifusion Optimization
if hasattr(self, "patched_parallel_size"):
from torch import distributed as dist
if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
self.output_buffer = torch.empty_like(output)
if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
output = output.contiguous()
dist.all_gather(self.buffer_list, output, async_op=False)
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
output = self.output_buffer
return (output,)
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
def SD3Transformer2DModel_forward(
self: SD3Transformer2DModel,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
pooled_projections: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor]:
assert hasattr(
self, "patched_parallel_size"
), "please check your policy: `Transformer2DModel` must have the attribute `patched_parallel_size`"
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size // self.patched_parallel_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
# enable Distrifusion Optimization
if hasattr(self, "patched_parallel_size"):
from torch import distributed as dist
if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
self.output_buffer = torch.empty_like(output)
if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
output = output.contiguous()
dist.all_gather(self.buffer_list, output, async_op=False)
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
output = self.output_buffer
return (output,)
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
class DistrifusionPatchEmbed(ParallelModule):
def __init__(
self,
module: PatchEmbed,
process_group: Union[ProcessGroup, List[ProcessGroup]],
model_shard_infer_config: ModelShardInferenceConfig = None,
):
super().__init__()
self.module = module
self.rank = dist.get_rank(group=process_group)
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
@staticmethod
def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
distrifusion_embed = DistrifusionPatchEmbed(
module, process_group, model_shard_infer_config=model_shard_infer_config
)
return distrifusion_embed
def forward(self, latent):
module = self.module
if module.pos_embed_max_size is not None:
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size
latent = module.proj(latent)
if module.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if module.layer_norm:
latent = module.norm(latent)
if module.pos_embed is None:
return latent.to(latent.dtype)
# Interpolate or crop positional embeddings as needed
if module.pos_embed_max_size:
pos_embed = module.cropped_pos_embed(height, width)
else:
if module.height != height or module.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=module.pos_embed.shape[-1],
grid_size=(height, width),
base_size=module.base_size,
interpolation_scale=module.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
else:
pos_embed = module.pos_embed
b, c, h = pos_embed.shape
pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]
return (latent + pos_embed).to(latent.dtype)
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
class DistrifusionConv2D(ParallelModule):
def __init__(
self,
module: nn.Conv2d,
process_group: Union[ProcessGroup, List[ProcessGroup]],
model_shard_infer_config: ModelShardInferenceConfig = None,
):
super().__init__()
self.module = module
self.rank = dist.get_rank(group=process_group)
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
@staticmethod
def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
return distrifusion_conv
def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, h, w = x.shape
stride = self.module.stride[0]
padding = self.module.padding[0]
output_h = x.shape[2] // stride // self.patched_parallelism_size
idx = dist.get_rank()
h_begin = output_h * idx * stride - padding
h_end = output_h * (idx + 1) * stride + padding
final_padding = [padding, padding, 0, 0]
if h_begin < 0:
h_begin = 0
final_padding[2] = padding
if h_end > h:
h_end = h
final_padding[3] = padding
sliced_input = x[:, :, h_begin:h_end, :]
padded_input = F.pad(sliced_input, final_padding, mode="constant")
return F.conv2d(
padded_input,
self.module.weight,
self.module.bias,
stride=stride,
padding="valid",
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = self.sliced_forward(input)
return output
# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
class DistrifusionFusedAttention(ParallelModule):
def __init__(
self,
module: attention_processor.Attention,
process_group: Union[ProcessGroup, List[ProcessGroup]],
model_shard_infer_config: ModelShardInferenceConfig = None,
):
super().__init__()
self.counter = 0
self.module = module
self.buffer_list = None
self.kv_buffer_idx = dist.get_rank(group=process_group)
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
self.handle = None
self.process_group = process_group
self.warm_step = 5 # for warmup
@staticmethod
def from_native_module(
module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
return DistrifusionFusedAttention(
module=module,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
)
def _forward(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
kv = torch.cat([key, value], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
if self.patched_parallelism_size == 1:
full_kv = kv
else:
if self.buffer_list is None: # buffer not created
full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
elif self.counter <= self.warm_step:
# logger.info(f"warmup: {self.counter}")
dist.all_gather(
self.buffer_list,
kv,
group=self.process_group,
async_op=False,
)
full_kv = torch.cat(self.buffer_list, dim=1)
else:
# logger.info(f"use old kv to infer: {self.counter}")
self.buffer_list[self.kv_buffer_idx].copy_(kv)
full_kv = torch.cat(self.buffer_list, dim=1)
assert self.handle is None, "the previous step's async KV all-gather should already have been waited on"
self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# attention
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
) # NOTE(@lry89757) for torch >= 2.2, flash attention has already been integrated into scaled_dot_product_attention, see https://pytorch.org/blog/pytorch2-2/
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
if self.handle is not None:
self.handle.wait()
self.handle = None
b, l, c = hidden_states.shape
kv_shape = (b, l, self.module.to_k.out_features * 2)
if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
self.buffer_list = [
torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
for _ in range(self.patched_parallelism_size)
]
self.counter = 0
attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
output = self._forward(
self.module,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
self.counter += 1
return output
# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
class DistriSelfAttention(ParallelModule):
def __init__(
self,
module: Attention,
process_group: Union[ProcessGroup, List[ProcessGroup]],
model_shard_infer_config: ModelShardInferenceConfig = None,
):
super().__init__()
self.counter = 0
self.module = module
self.buffer_list = None
self.kv_buffer_idx = dist.get_rank(group=process_group)
self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
self.handle = None
self.process_group = process_group
self.warm_step = 3 # for warmup
@staticmethod
def from_native_module(
module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
return DistriSelfAttention(
module=module,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
)
def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
attn = self.module
assert isinstance(attn, Attention)
residual = hidden_states
batch_size, sequence_length, _ = hidden_states.shape
query = attn.to_q(hidden_states)
encoder_hidden_states = hidden_states
k = self.module.to_k(encoder_hidden_states)
v = self.module.to_v(encoder_hidden_states)
kv = torch.cat([k, v], dim=-1) # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
if self.patched_parallelism_size == 1:
full_kv = kv
else:
if self.buffer_list is None: # buffer not created
full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
elif self.counter <= self.warm_step:
# logger.info(f"warmup: {self.counter}")
dist.all_gather(
self.buffer_list,
kv,
group=self.process_group,
async_op=False,
)
full_kv = torch.cat(self.buffer_list, dim=1)
else:
# logger.info(f"use old kv to infer: {self.counter}")
self.buffer_list[self.kv_buffer_idx].copy_(kv)
full_kv = torch.cat(self.buffer_list, dim=1)
assert self.handle is None, "the previous step's async KV all-gather should already have been waited on"
self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)
if HAS_FLASH_ATTN:
# flash attn
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
else:
# naive attn
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
# wait for last step's async KV all-gather, then (re)allocate the KV buffers if their shape changed
if self.handle is not None:
self.handle.wait()
self.handle = None
b, l, c = hidden_states.shape
kv_shape = (b, l, self.module.to_k.out_features * 2)
if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
self.buffer_list = [
torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
for _ in range(self.patched_parallelism_size)
]
self.counter = 0
output = self._forward(hidden_states, scale=scale)
self.counter += 1
return output
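
Both attention wrappers above implement the same computation-communication overlap at the heart of Distrifusion: for the first few warmup diffusion steps the KV shards are all-gathered synchronously, and afterwards each step reuses the slightly stale KV gathered during the previous step while launching an asynchronous all-gather that is waited on at the start of the next step. A stripped-down sketch of that pattern (illustrative only; the class name and buffer handling are simplifications, not part of this PR):

```python
import torch
import torch.distributed as dist


class StaleKVAllGather:
    """Warmup with synchronous all-gather, then overlap communication with compute."""

    def __init__(self, world_size: int, rank: int, warm_steps: int = 3):
        self.world_size, self.rank, self.warm_steps = world_size, rank, warm_steps
        self.buffer_list = None  # one KV buffer per rank
        self.handle = None       # pending async all-gather from the previous step
        self.counter = 0         # diffusion step counter

    def gather(self, kv: torch.Tensor) -> torch.Tensor:
        if self.handle is not None:  # finish the previous step's communication first
            self.handle.wait()
            self.handle = None
        if self.buffer_list is None or self.buffer_list[0].shape != kv.shape:
            self.buffer_list = [torch.empty_like(kv) for _ in range(self.world_size)]
            self.counter = 0
        if self.counter <= self.warm_steps:
            # warmup: block until every rank's fresh KV has arrived
            dist.all_gather(self.buffer_list, kv, async_op=False)
            full_kv = torch.cat(self.buffer_list, dim=1)
        else:
            # steady state: fresh local KV, stale (previous-step) KV from other ranks
            self.buffer_list[self.rank].copy_(kv)
            full_kv = torch.cat(self.buffer_list, dim=1)
            # refresh the remote KV for the next step while this step keeps computing
            self.handle = dist.all_gather(self.buffer_list, kv, async_op=True)
        self.counter += 1
        return full_kv
```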

View File

@ -14,7 +14,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retri
from colossalai.logging import get_dist_logger
from .diffusion import DiffusionPipe
from ..layers.diffusion import DiffusionPipe
logger = get_dist_logger(__name__)

View File

@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
from .diffusion import DiffusionPipe
from ..layers.diffusion import DiffusionPipe
# TODO(@lry89757) temporarily returns only images; support more output types later

View File

@ -1,9 +1,17 @@
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.distrifusion import (
DistrifusionConv2D,
DistrifusionPatchEmbed,
DistriSelfAttention,
PixArtAlphaTransformer2DModel_forward,
)
from colossalai.inference.modeling.models.pixart_alpha import pixart_alpha_forward
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
@ -12,9 +20,46 @@ class PixArtAlphaInferPolicy(Policy, RPC_PARAM):
def module_policy(self):
policy = {}
if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
policy[PixArtTransformer2DModel] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="pos_embed.proj",
target_module=DistrifusionConv2D,
kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
),
SubModuleReplacementDescription(
suffix="pos_embed",
target_module=DistrifusionPatchEmbed,
kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
),
],
attribute_replacement={
"patched_parallel_size": self.shard_config.extra_kwargs[
"model_shard_infer_config"
].patched_parallelism_size
},
method_replacement={"forward": PixArtAlphaTransformer2DModel_forward},
)
policy[BasicTransformerBlock] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn1",
target_module=DistriSelfAttention,
kwargs={
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
},
)
]
)
self.append_or_create_method_replacement(
description={"forward": pixart_alpha_forward}, policy=policy, target_key=DiffusionPipe
)
return policy
def preprocess(self) -> nn.Module:

View File

@ -1,9 +1,17 @@
from diffusers.models.attention import JointTransformerBlock
from diffusers.models.transformers import SD3Transformer2DModel
from torch import nn
from colossalai.inference.config import RPC_PARAM
from colossalai.inference.modeling.models.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.diffusion import DiffusionPipe
from colossalai.inference.modeling.layers.distrifusion import (
DistrifusionConv2D,
DistrifusionFusedAttention,
DistrifusionPatchEmbed,
SD3Transformer2DModel_forward,
)
from colossalai.inference.modeling.models.stablediffusion3 import sd3_forward
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
@ -12,6 +20,42 @@ class StableDiffusion3InferPolicy(Policy, RPC_PARAM):
def module_policy(self):
policy = {}
if self.shard_config.extra_kwargs["model_shard_infer_config"].patched_parallelism_size > 1:
policy[SD3Transformer2DModel] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="pos_embed.proj",
target_module=DistrifusionConv2D,
kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
),
SubModuleReplacementDescription(
suffix="pos_embed",
target_module=DistrifusionPatchEmbed,
kwargs={"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"]},
),
],
attribute_replacement={
"patched_parallel_size": self.shard_config.extra_kwargs[
"model_shard_infer_config"
].patched_parallelism_size
},
method_replacement={"forward": SD3Transformer2DModel_forward},
)
policy[JointTransformerBlock] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn",
target_module=DistrifusionFusedAttention,
kwargs={
"model_shard_infer_config": self.shard_config.extra_kwargs["model_shard_infer_config"],
},
)
]
)
self.append_or_create_method_replacement(
description={"forward": sd3_forward}, policy=policy, target_key=DiffusionPipe
)

View File

@ -0,0 +1,22 @@
## File Structure
```
|- sd3_generation.py: an example of using the ColossalAI Inference Engine to generate images with a loaded diffusion model.
|- compute_metric.py: compare the quality of images generated with and without acceleration methods such as Distrifusion
|- benchmark_sd3.py: benchmark the performance of our InferenceEngine
|- run_benchmark.sh: run the benchmark commands
```
Note: `compute_metric.py` requires extra dependencies; install them with `pip install -r requirements.txt` (the `requirements.txt` lives in `examples/inference/stable_diffusion/`).
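For reference, a quality comparison could be run as follows (directory names are placeholders; the flags correspond to the argument parser of `compute_metric.py` added below in this PR):
```bash
pip install -r examples/inference/stable_diffusion/requirements.txt
# compare images generated with and without Distrifusion (paths are placeholders)
python examples/inference/stable_diffusion/compute_metric.py \
    --input_root0 ./images_baseline \
    --input_root1 ./images_distrifusion
```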
## Run Inference
The provided `sd3_generation.py` shows how to configure the engine, initialize it, and run inference on a given model. We use `DiffusionPipeline` as the model class, so the script works out of the box with Stable Diffusion 3.
For a basic setting, you can run the example with:
```bash
colossalai run --nproc_per_node 1 sd3_generation.py -m PATH_MODEL -p "hello world"
```
Run multi-GPU inference (Patched Parallelism), as in the following example using 2 GPUs:
```bash
colossalai run --nproc_per_node 2 sd3_generation.py -m PATH_MODEL
```

View File

@ -0,0 +1,179 @@
import argparse
import json
import time
from contextlib import nullcontext
import torch
import torch.distributed as dist
from diffusers import DiffusionPipeline
import colossalai
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
_DTYPE_MAPPING = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
def log_generation_time(log_data, log_file):
with open(log_file, "a") as f:
json.dump(log_data, f, indent=2)
f.write("\n")
def warmup(engine, args):
for _ in range(args.n_warm_up_steps):
engine.generate(
prompts=["hello world"],
generation_config=DiffusionGenerationConfig(
num_inference_steps=args.num_inference_steps, height=args.height[0], width=args.width[0]
),
)
def profile_context(args):
return (
torch.profiler.profile(
record_shapes=True,
with_stack=True,
with_modules=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
)
if args.profile
else nullcontext()
)
def log_and_profile(h, w, avg_time, log_msg, args, model_name, mode, prof=None):
log_data = {
"mode": mode,
"model": model_name,
"batch_size": args.batch_size,
"patched_parallel_size": args.patched_parallel_size,
"num_inference_steps": args.num_inference_steps,
"height": h,
"width": w,
"dtype": args.dtype,
"profile": args.profile,
"n_warm_up_steps": args.n_warm_up_steps,
"n_repeat_times": args.n_repeat_times,
"avg_generation_time": avg_time,
"log_message": log_msg,
}
if args.log:
log_file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}.json"
log_generation_time(log_data=log_data, log_file=log_file)
if args.profile:
file = f"examples/inference/stable_diffusion/benchmark_{model_name}_{mode}_prof.json"
prof.export_chrome_trace(file)
def benchmark_colossalai(rank, world_size, port, args):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
from colossalai.cluster.dist_coordinator import DistCoordinator
coordinator = DistCoordinator()
inference_config = InferenceConfig(
dtype=args.dtype,
patched_parallelism_size=args.patched_parallel_size,
)
engine = InferenceEngine(args.model, inference_config=inference_config, verbose=False)
warmup(engine, args)
for h, w in zip(args.height, args.width):
with profile_context(args) as prof:
start = time.perf_counter()
for _ in range(args.n_repeat_times):
engine.generate(
prompts=["hello world"],
generation_config=DiffusionGenerationConfig(
num_inference_steps=args.num_inference_steps, height=h, width=w
),
)
end = time.perf_counter()
avg_time = (end - start) / args.n_repeat_times
log_msg = f"[ColossalAI]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
coordinator.print_on_master(log_msg)
if dist.get_rank() == 0:
log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "colossalai", prof=prof)
def benchmark_diffusers(args):
model = DiffusionPipeline.from_pretrained(args.model, torch_dtype=_DTYPE_MAPPING[args.dtype]).to("cuda")
for _ in range(args.n_warm_up_steps):
model(
prompt="hello world",
num_inference_steps=args.num_inference_steps,
height=args.height[0],
width=args.width[0],
)
for h, w in zip(args.height, args.width):
with profile_context(args) as prof:
start = time.perf_counter()
for _ in range(args.n_repeat_times):
model(prompt="hello world", num_inference_steps=args.num_inference_steps, height=h, width=w)
end = time.perf_counter()
avg_time = (end - start) / args.n_repeat_times
log_msg = f"[Diffusers]avg generation time for h({h})xw({w}) is {avg_time:.2f}s"
print(log_msg)
log_and_profile(h, w, avg_time, log_msg, args, args.model.split("/")[-1], "diffusers", prof)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def benchmark(args):
if args.mode == "colossalai":
spawn(benchmark_colossalai, nprocs=args.patched_parallel_size, args=args)
elif args.mode == "diffusers":
benchmark_diffusers(args)
"""
# enable log
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --log
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --log
# enable profiler
python examples/inference/stable_diffusion/benchmark_sd3.py -m "stabilityai/stable-diffusion-3-medium-diffusers" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" -p 2 --mode colossalai --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
python examples/inference/stable_diffusion/benchmark_sd3.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --mode diffusers --n_warm_up_steps 3 --n_repeat_times 1 --profile --num_inference_steps 20
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("-p", "--patched_parallel_size", type=int, default=1, help="Patched Parallelism size")
parser.add_argument("-n", "--num_inference_steps", type=int, default=50, help="Number of inference steps")
parser.add_argument("-H", "--height", type=int, nargs="+", default=[1024, 2048], help="Height list")
parser.add_argument("-w", "--width", type=int, nargs="+", default=[1024, 2048], help="Width list")
parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
parser.add_argument("--n_warm_up_steps", type=int, default=3, help="Number of warm up steps")
parser.add_argument("--n_repeat_times", type=int, default=5, help="Number of repeat times")
parser.add_argument("--profile", default=False, action="store_true", help="Enable torch profiler")
parser.add_argument("--log", default=False, action="store_true", help="Enable logging")
parser.add_argument("-m", "--model", default="stabilityai/stable-diffusion-3-medium-diffusers", help="Model path")
parser.add_argument(
"--mode", default="colossalai", choices=["colossalai", "diffusers"], help="Inference framework mode"
)
args = parser.parse_args()
benchmark(args)

View File

@ -0,0 +1,80 @@
# Code from https://github.com/mit-han-lab/distrifuser/blob/main/scripts/compute_metrics.py
import argparse
import os
import numpy as np
import torch
from cleanfid import fid
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity, PeakSignalNoiseRatio
from torchvision.transforms import Resize
from tqdm import tqdm
def read_image(path: str):
"""
input: path
output: tensor (C, H, W)
"""
img = np.asarray(Image.open(path))
if len(img.shape) == 2:
img = np.repeat(img[:, :, None], 3, axis=2)
img = torch.from_numpy(img).permute(2, 0, 1)
return img
class MultiImageDataset(Dataset):
def __init__(self, root0, root1, is_gt=False):
super().__init__()
self.root0 = root0
self.root1 = root1
file_names0 = os.listdir(root0)
file_names1 = os.listdir(root1)
self.image_names0 = sorted([name for name in file_names0 if name.endswith(".png") or name.endswith(".jpg")])
self.image_names1 = sorted([name for name in file_names1 if name.endswith(".png") or name.endswith(".jpg")])
self.is_gt = is_gt
assert len(self.image_names0) == len(self.image_names1)
def __len__(self):
return len(self.image_names0)
def __getitem__(self, idx):
img0 = read_image(os.path.join(self.root0, self.image_names0[idx]))
if self.is_gt:
# resize to 1024 x 1024
img0 = Resize((1024, 1024))(img0)
img1 = read_image(os.path.join(self.root1, self.image_names1[idx]))
batch_list = [img0, img1]
return batch_list
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--is_gt", action="store_true")
parser.add_argument("--input_root0", type=str, required=True)
parser.add_argument("--input_root1", type=str, required=True)
args = parser.parse_args()
psnr = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to("cuda")
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
dataset = MultiImageDataset(args.input_root0, args.input_root1, is_gt=args.is_gt)
dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
progress_bar = tqdm(dataloader)
with torch.inference_mode():
for i, batch in enumerate(progress_bar):
batch = [img.to("cuda") / 255 for img in batch]
batch_size = batch[0].shape[0]
psnr.update(batch[0], batch[1])
lpips.update(batch[0], batch[1])
fid_score = fid.compute_fid(args.input_root0, args.input_root1)
print("PSNR:", psnr.compute().item())
print("LPIPS:", lpips.compute().item())
print("FID:", fid_score)

View File

@ -0,0 +1,3 @@
torchvision
torchmetrics
cleanfid

View File

@ -0,0 +1,42 @@
#!/bin/bash
models=("PixArt-alpha/PixArt-XL-2-1024-MS" "stabilityai/stable-diffusion-3-medium-diffusers")
parallelism=(1 2 4 8)
resolutions=(1024 2048 3840)
modes=("colossalai" "diffusers")
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
for model in "${models[@]}"; do
for p in "${parallelism[@]}"; do
for resolution in "${resolutions[@]}"; do
for mode in "${modes[@]}"; do
if [[ "$mode" == "colossalai" && "$p" == 1 ]]; then
continue
fi
if [[ "$mode" == "diffusers" && "$p" != 1 ]]; then
continue
fi
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage $p
cmd="python examples/inference/stable_diffusion/benchmark_sd3.py -m \"$model\" -p $p --mode $mode --log -H $resolution -w $resolution"
echo "Executing: $cmd"
eval $cmd
done
done
done
done

View File

@ -1,18 +1,17 @@
import argparse
from diffusers import PixArtAlphaPipeline, StableDiffusion3Pipeline
from torch import bfloat16, float16, float32
from diffusers import DiffusionPipeline
from torch import bfloat16
from torch import distributed as dist
from torch import float16, float32
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy.pixart_alpha import PixArtAlphaInferPolicy
from colossalai.inference.modeling.policy.stablediffusion3 import StableDiffusion3InferPolicy
# For Stable Diffusion 3, we'll use the following configuration
MODEL_CLS = [StableDiffusion3Pipeline, PixArtAlphaPipeline][0]
POLICY_CLS = [StableDiffusion3InferPolicy, PixArtAlphaInferPolicy][0]
MODEL_CLS = DiffusionPipeline
TORCH_DTYPE_MAP = {
"fp16": float16,
@ -43,20 +42,27 @@ def infer(args):
max_batch_size=args.max_batch_size,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
patched_parallelism_size=dist.get_world_size(),
)
engine = InferenceEngine(model, inference_config=inference_config, model_policy=POLICY_CLS(), verbose=True)
engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
# ==============================
# Generation
# ==============================
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
out.save("cat.jpg")
if dist.get_rank() == 0:
out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
coordinator.print_on_master(out)
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
if __name__ == "__main__":