# Code referenced and adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
# https://github.com/PipeFusion/PipeFusion
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from diffusers.models import attention_processor
from diffusers.models.attention import Attention
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torch import nn
from torch.distributed import ProcessGroup

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.utils import get_current_device

try:
    from flash_attn import flash_attn_func

    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

logger = get_dist_logger(__name__)

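# NOTE: these layers are not meant to be constructed directly; the Distrifusion inference
# policy swaps them in for the corresponding diffusers modules via `from_native_module`.
# Illustrative sketch only (variable names are assumptions, the real wiring lives in the policy):
#
#   patched_attn = DistriSelfAttention.from_native_module(
#       attn, process_group, model_shard_infer_config=model_shard_infer_config
#   )
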
# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
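# Patched forward for the Distrifusion patch-parallel scheme: each rank runs the transformer
# blocks on its own horizontal slice of the latent (height divided by `patched_parallel_size`),
# and the per-rank outputs are all-gathered and concatenated along the height dimension (dim 2)
# at the end of the forward pass.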
def PixArtAlphaTransformer2DModel_forward(
    self: PixArtTransformer2DModel,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    assert hasattr(
        self, "patched_parallel_size"
    ), "please check your policy, `Transformer2DModel` must have attribute `patched_parallel_size`"

    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None and attention_mask.ndim == 2:
        # assume that mask is expressed as:
        #   (1 = keep, 0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #   (keep = +0, discard = -10000.0)
        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # 1. Input
    batch_size = hidden_states.shape[0]
    height, width = (
        hidden_states.shape[-2] // self.config.patch_size,
        hidden_states.shape[-1] // self.config.patch_size,
    )
    hidden_states = self.pos_embed(hidden_states)

    timestep, embedded_timestep = self.adaln_single(
        timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
    )

    if self.caption_projection is not None:
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

    # 2. Blocks
    for block in self.transformer_blocks:
        hidden_states = block(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            timestep=timestep,
            cross_attention_kwargs=cross_attention_kwargs,
            class_labels=class_labels,
        )

    # 3. Output
    shift, scale = (
        self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
    ).chunk(2, dim=1)
    hidden_states = self.norm_out(hidden_states)
    # Modulation
    hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
    hidden_states = self.proj_out(hidden_states)
    hidden_states = hidden_states.squeeze(1)

    # unpatchify
    hidden_states = hidden_states.reshape(
        shape=(
            -1,
            height // self.patched_parallel_size,
            width,
            self.config.patch_size,
            self.config.patch_size,
            self.out_channels,
        )
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(
            -1,
            self.out_channels,
            height // self.patched_parallel_size * self.config.patch_size,
            width * self.config.patch_size,
        )
    )

    # enable Distrifusion Optimization
    if hasattr(self, "patched_parallel_size"):
        from torch import distributed as dist

        if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
            self.output_buffer = torch.empty_like(output)
        if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
            self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
        output = output.contiguous()
        dist.all_gather(self.buffer_list, output, async_op=False)
        torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
        output = self.output_buffer

    return (output,)

# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
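# SD3 analogue of the patched forward above: each rank processes its own latent slice through
# the joint transformer blocks, then the per-rank outputs are all-gathered and concatenated
# along the height dimension before returning.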
def SD3Transformer2DModel_forward(
    self: SD3Transformer2DModel,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    pooled_projections: torch.FloatTensor = None,
    timestep: torch.LongTensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Tuple[torch.FloatTensor]:
    assert hasattr(
        self, "patched_parallel_size"
    ), "please check your policy, `Transformer2DModel` must have attribute `patched_parallel_size`"

    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
    temb = self.time_text_embed(timestep, pooled_projections)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    for block in self.transformer_blocks:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
        )

    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # unpatchify
    patch_size = self.config.patch_size
    height = height // patch_size // self.patched_parallel_size
    width = width // patch_size

    hidden_states = hidden_states.reshape(
        shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
    )

    # enable Distrifusion Optimization
    if hasattr(self, "patched_parallel_size"):
        from torch import distributed as dist

        if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
            self.output_buffer = torch.empty_like(output)
        if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
            self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
        output = output.contiguous()
        dist.all_gather(self.buffer_list, output, async_op=False)
        torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
        output = self.output_buffer

    return (output,)

# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
class DistrifusionPatchEmbed(ParallelModule):
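    """Patch-parallel wrapper around diffusers' ``PatchEmbed``.

    Each rank embeds only its own horizontal slice of the latent, so the positional
    embedding is split along the token dimension into ``patched_parallelism_size``
    chunks and this rank adds only its own chunk.
    """
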
    def __init__(
        self,
        module: PatchEmbed,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.module = module
        self.rank = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size

    @staticmethod
    def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        distrifusion_embed = DistrifusionPatchEmbed(
            module, process_group, model_shard_infer_config=model_shard_infer_config
        )
        return distrifusion_embed

    def forward(self, latent):
        module = self.module
        if module.pos_embed_max_size is not None:
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size

        latent = module.proj(latent)
        if module.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if module.layer_norm:
            latent = module.norm(latent)
        if module.pos_embed is None:
            return latent.to(latent.dtype)

        # Interpolate or crop positional embeddings as needed
        if module.pos_embed_max_size:
            pos_embed = module.cropped_pos_embed(height, width)
        else:
            if module.height != height or module.width != width:
                pos_embed = get_2d_sincos_pos_embed(
                    embed_dim=module.pos_embed.shape[-1],
                    grid_size=(height, width),
                    base_size=module.base_size,
                    interpolation_scale=module.interpolation_scale,
                )
                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
            else:
                pos_embed = module.pos_embed

        b, c, h = pos_embed.shape
        pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]

        return (latent + pos_embed).to(latent.dtype)

# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
class DistrifusionConv2D(ParallelModule):
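    """Patch-parallel wrapper around ``nn.Conv2d``.

    ``sliced_forward`` cuts out the rows this rank is responsible for (plus the extra
    rows needed by the kernel's receptive field), zero-pads only at the true image
    borders, and runs a valid convolution, so concatenating the per-rank outputs along
    the height reproduces the full convolution.
    """
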
    def __init__(
        self,
        module: nn.Conv2d,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.module = module
        self.rank = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size

    @staticmethod
    def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
        return distrifusion_conv

    def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        stride = self.module.stride[0]
        padding = self.module.padding[0]

        output_h = x.shape[2] // stride // self.patched_parallelism_size
        idx = dist.get_rank()
        h_begin = output_h * idx * stride - padding
        h_end = output_h * (idx + 1) * stride + padding
        final_padding = [padding, padding, 0, 0]
        if h_begin < 0:
            h_begin = 0
            final_padding[2] = padding
        if h_end > h:
            h_end = h
            final_padding[3] = padding
        sliced_input = x[:, :, h_begin:h_end, :]
        padded_input = F.pad(sliced_input, final_padding, mode="constant")
        return F.conv2d(
            padded_input,
            self.module.weight,
            self.module.bias,
            stride=stride,
            padding="valid",
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = self.sliced_forward(input)
        return output

# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
class DistrifusionFusedAttention(ParallelModule):
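    """Distrifusion joint (text + image) attention for SD3-style transformer blocks.

    Queries are computed from the local patch only, while keys/values are gathered
    across ranks. For the first ``warm_step`` steps the KV all_gather is synchronous;
    afterwards the previous step's (slightly stale) KV buffers are reused while a fresh
    all_gather is launched asynchronously, overlapping communication with computation.
    """
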
    def __init__(
        self,
        module: attention_processor.Attention,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.counter = 0
        self.module = module
        self.buffer_list = None
        self.kv_buffer_idx = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
        self.handle = None
        self.process_group = process_group
        self.warm_step = 5  # for warmup

    @staticmethod
    def from_native_module(
        module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        return DistrifusionFusedAttention(
            module=module,
            process_group=process_group,
            model_shard_infer_config=model_shard_infer_config,
        )

    def _forward(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        kv = torch.cat([key, value], dim=-1)  # shape of kv now: (bs, seq_len // parallel_size, dim * 2)

        if self.patched_parallelism_size == 1:
            full_kv = kv
        else:
            if self.buffer_list is None:  # buffer not created
                full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
            elif self.counter <= self.warm_step:
                # logger.info(f"warmup: {self.counter}")
                dist.all_gather(
                    self.buffer_list,
                    kv,
                    group=self.process_group,
                    async_op=False,
                )
                full_kv = torch.cat(self.buffer_list, dim=1)
            else:
                # logger.info(f"use old kv to infer: {self.counter}")
                self.buffer_list[self.kv_buffer_idx].copy_(kv)
                full_kv = torch.cat(self.buffer_list, dim=1)
                assert self.handle is None, "we should maintain the kv of last step"
                self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)

        key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )  # NOTE(@lry89757) for torch >= 2.2, flash attn has been already integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # wait for the async KV all_gather launched in the previous step, if any
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

        b, l, c = hidden_states.shape
        kv_shape = (b, l, self.module.to_k.out_features * 2)
        if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
            self.buffer_list = [
                torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
                for _ in range(self.patched_parallelism_size)
            ]
            self.counter = 0

        attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}

        output = self._forward(
            self.module,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        self.counter += 1

        return output

# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
class DistriSelfAttention(ParallelModule):
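    """Distrifusion self-attention for PixArt-style transformer blocks.

    Uses the same stale-KV scheme as ``DistrifusionFusedAttention``: synchronous KV
    all_gather during warmup, then asynchronous gathers whose results are consumed on
    the next diffusion step. Falls back to ``F.scaled_dot_product_attention`` when
    flash-attn is unavailable.
    """
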
    def __init__(
        self,
        module: Attention,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.counter = 0
        self.module = module
        self.buffer_list = None
        self.kv_buffer_idx = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
        self.handle = None
        self.process_group = process_group
        self.warm_step = 3  # for warmup

    @staticmethod
    def from_native_module(
        module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        return DistriSelfAttention(
            module=module,
            process_group=process_group,
            model_shard_infer_config=model_shard_infer_config,
        )

    def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
        attn = self.module
        assert isinstance(attn, Attention)

        residual = hidden_states

        batch_size, sequence_length, _ = hidden_states.shape

        query = attn.to_q(hidden_states)
        encoder_hidden_states = hidden_states
        k = self.module.to_k(encoder_hidden_states)
        v = self.module.to_v(encoder_hidden_states)
        kv = torch.cat([k, v], dim=-1)  # shape of kv now: (bs, seq_len // parallel_size, dim * 2)

        if self.patched_parallelism_size == 1:
            full_kv = kv
        else:
            if self.buffer_list is None:  # buffer not created
                full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
            elif self.counter <= self.warm_step:
                # logger.info(f"warmup: {self.counter}")
                dist.all_gather(
                    self.buffer_list,
                    kv,
                    group=self.process_group,
                    async_op=False,
                )
                full_kv = torch.cat(self.buffer_list, dim=1)
            else:
                # logger.info(f"use old kv to infer: {self.counter}")
                self.buffer_list[self.kv_buffer_idx].copy_(kv)
                full_kv = torch.cat(self.buffer_list, dim=1)
                assert self.handle is None, "we should maintain the kv of last step"
                self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)

        if HAS_FLASH_ATTN:
            # flash attn
            key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            query = query.view(batch_size, -1, attn.heads, head_dim)
            key = key.view(batch_size, -1, attn.heads, head_dim)
            value = value.view(batch_size, -1, attn.heads, head_dim)

            hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
        else:
            # naive attn
            key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)

            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # wait for the async KV all_gather launched last step, then (re)allocate the KV buffers if needed
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

        b, l, c = hidden_states.shape
        kv_shape = (b, l, self.module.to_k.out_features * 2)
        if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
            self.buffer_list = [
                torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
                for _ in range(self.patched_parallelism_size)
            ]
            self.counter = 0

        output = self._forward(hidden_states, scale=scale)

        self.counter += 1

        return output