# Code referenced and adapted from:
# https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers
# https://github.com/PipeFusion/PipeFusion

import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from diffusers.models import attention_processor
from diffusers.models.attention import Attention
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torch import nn
from torch.distributed import ProcessGroup

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.utils import get_current_device

try:
    from flash_attn import flash_attn_func

    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False


logger = get_dist_logger(__name__)
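
# NOTE: the modules below implement a patch-parallel ("Distrifusion"-style) inference scheme:
# the latent is split along its height across `patched_parallelism_size` ranks, each rank runs
# the transformer on its own slice, and the full image is reassembled with an all_gather at the
# end of the forward pass. In the attention modules, stale key/value activations from the
# previous diffusion step are reused so that the all_gather of fresh KV can run asynchronously
# and overlap with computation (see the `warm_step` handling below).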


# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_2d.py
def PixArtAlphaTransformer2DModel_forward(
    self: PixArtTransformer2DModel,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    added_cond_kwargs: Dict[str, torch.Tensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    assert hasattr(
        self, "patched_parallel_size"
    ), "please check your policy, `Transformer2DModel` must have attribute `patched_parallel_size`"

    if cross_attention_kwargs is not None:
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None and attention_mask.ndim == 2:
        # assume that mask is expressed as:
        #   (1 = keep, 0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #   (keep = +0, discard = -10000.0)
        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
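
    # Example of the mask-to-bias conversion above: a 2-D mask [[1, 1, 0]] (keep, keep, discard)
    # becomes the additive bias [[[0.0, 0.0, -10000.0]]], which broadcasts over the
    # attention-score shapes listed above.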

    # 1. Input
    batch_size = hidden_states.shape[0]
    height, width = (
        hidden_states.shape[-2] // self.config.patch_size,
        hidden_states.shape[-1] // self.config.patch_size,
    )
    hidden_states = self.pos_embed(hidden_states)

    timestep, embedded_timestep = self.adaln_single(
        timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
    )

    if self.caption_projection is not None:
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

    # 2. Blocks
    for block in self.transformer_blocks:
        hidden_states = block(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            timestep=timestep,
            cross_attention_kwargs=cross_attention_kwargs,
            class_labels=class_labels,
        )

    # 3. Output
    shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(
        2, dim=1
    )
    hidden_states = self.norm_out(hidden_states)
    # Modulation
    hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
    hidden_states = self.proj_out(hidden_states)
    hidden_states = hidden_states.squeeze(1)

    # unpatchify
    hidden_states = hidden_states.reshape(
        shape=(
            -1,
            height // self.patched_parallel_size,
            width,
            self.config.patch_size,
            self.config.patch_size,
            self.out_channels,
        )
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(
            -1,
            self.out_channels,
            height // self.patched_parallel_size * self.config.patch_size,
            width * self.config.patch_size,
        )
    )
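
    # At this point each rank holds only its slice of the image along the height dimension:
    # `output` has shape (batch, out_channels, height // patched_parallel_size * patch_size,
    # width * patch_size). The all_gather below concatenates the slices along dim 2 (height)
    # to rebuild the full latent.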

    # enable Distrifusion Optimization
    if hasattr(self, "patched_parallel_size"):
        from torch import distributed as dist

        if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
            self.output_buffer = torch.empty_like(output)
        if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
            self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
        output = output.contiguous()
        dist.all_gather(self.buffer_list, output, async_op=False)
        torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
        output = self.output_buffer

    return (output,)


# adapted from https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/transformers/transformer_sd3.py
def SD3Transformer2DModel_forward(
    self: SD3Transformer2DModel,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    pooled_projections: torch.FloatTensor = None,
    timestep: torch.LongTensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Tuple[torch.FloatTensor]:
    assert hasattr(
        self, "patched_parallel_size"
    ), "please check your policy, `Transformer2DModel` must have attribute `patched_parallel_size`"

    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
    temb = self.time_text_embed(timestep, pooled_projections)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    for block in self.transformer_blocks:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
        )

    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # unpatchify
    patch_size = self.config.patch_size
    height = height // patch_size // self.patched_parallel_size
    width = width // patch_size

    hidden_states = hidden_states.reshape(
        shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
    )

    # enable Distrifusion Optimization
    if hasattr(self, "patched_parallel_size"):
        from torch import distributed as dist

        if (getattr(self, "output_buffer", None) is None) or (self.output_buffer.shape != output.shape):
            self.output_buffer = torch.empty_like(output)
        if (getattr(self, "buffer_list", None) is None) or (self.buffer_list[0].shape != output.shape):
            self.buffer_list = [torch.empty_like(output) for _ in range(self.patched_parallel_size)]
        output = output.contiguous()
        # gather each rank's height slice (dim 2) to reassemble the full output, as in the PixArt forward above
        dist.all_gather(self.buffer_list, output, async_op=False)
        torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
        output = self.output_buffer

    return (output,)


# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/patchembed.py
class DistrifusionPatchEmbed(ParallelModule):
    def __init__(
        self,
        module: PatchEmbed,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.module = module
        self.rank = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size

    @staticmethod
    def from_native_module(module: PatchEmbed, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        distrifusion_embed = DistrifusionPatchEmbed(
            module, process_group, model_shard_infer_config=model_shard_infer_config
        )
        return distrifusion_embed

    def forward(self, latent):
        module = self.module
        if module.pos_embed_max_size is not None:
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // module.patch_size, latent.shape[-1] // module.patch_size

        latent = module.proj(latent)
        if module.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if module.layer_norm:
            latent = module.norm(latent)
        if module.pos_embed is None:
            return latent.to(latent.dtype)
        # Interpolate or crop positional embeddings as needed
        if module.pos_embed_max_size:
            pos_embed = module.cropped_pos_embed(height, width)
        else:
            if module.height != height or module.width != width:
                pos_embed = get_2d_sincos_pos_embed(
                    embed_dim=module.pos_embed.shape[-1],
                    grid_size=(height, width),
                    base_size=module.base_size,
                    interpolation_scale=module.interpolation_scale,
                )
                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
            else:
                pos_embed = module.pos_embed

        b, c, h = pos_embed.shape
        pos_embed = pos_embed.view(b, self.patched_parallelism_size, -1, h)[:, self.rank]
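        # pos_embed has shape (1, num_patches, embed_dim); the view above splits the patch
        # (sequence) dimension into patched_parallelism_size contiguous chunks, and indexing with
        # self.rank keeps only the chunk matching this rank's height slice of the latent.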

        return (latent + pos_embed).to(latent.dtype)


# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/conv2d.py
class DistrifusionConv2D(ParallelModule):
    def __init__(
        self,
        module: nn.Conv2d,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.module = module
        self.rank = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size

    @staticmethod
    def from_native_module(module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs):
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        distrifusion_conv = DistrifusionConv2D(module, process_group, model_shard_infer_config=model_shard_infer_config)
        return distrifusion_conv

    def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape

        stride = self.module.stride[0]
        padding = self.module.padding[0]

        output_h = x.shape[2] // stride // self.patched_parallelism_size
        idx = dist.get_rank()  # NOTE: global rank; assumes the patch-parallel group spans all ranks
        h_begin = output_h * idx * stride - padding
        h_end = output_h * (idx + 1) * stride + padding
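        # Each rank convolves only its own horizontal band: [h_begin, h_end) covers this rank's
        # share of input rows plus `padding` halo rows on each side; at the top/bottom image
        # borders the missing halo is supplied by explicit zero padding (final_padding below),
        # so the sliced result corresponds to this rank's rows of the full convolution output.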
        final_padding = [padding, padding, 0, 0]
        if h_begin < 0:
            h_begin = 0
            final_padding[2] = padding
        if h_end > h:
            h_end = h
            final_padding[3] = padding
        sliced_input = x[:, :, h_begin:h_end, :]
        padded_input = F.pad(sliced_input, final_padding, mode="constant")
        return F.conv2d(
            padded_input,
            self.module.weight,
            self.module.bias,
            stride=stride,
            padding="valid",
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = self.sliced_forward(input)
        return output


# Code adapted from: https://github.com/huggingface/diffusers/blob/v0.29.0-release/src/diffusers/models/attention_processor.py
class DistrifusionFusedAttention(ParallelModule):
    def __init__(
        self,
        module: attention_processor.Attention,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.counter = 0
        self.module = module
        self.buffer_list = None
        self.kv_buffer_idx = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
        self.handle = None
        self.process_group = process_group
        self.warm_step = 5  # for warmup

    @staticmethod
    def from_native_module(
        module: attention_processor.Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        return DistrifusionFusedAttention(
            module=module,
            process_group=process_group,
            model_shard_infer_config=model_shard_infer_config,
        )

    def _forward(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        kv = torch.cat([key, value], dim=-1)  # shape of kv now: (bs, seq_len // parallel_size, dim * 2)
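
        # KV handling across patch-parallel ranks (Distrifusion-style stale KV):
        #   - first call (no buffer yet): replicate the local KV so shapes match the full sequence;
        #   - warmup steps (counter <= warm_step): synchronous all_gather of fresh KV from all ranks;
        #   - afterwards: reuse last step's gathered KV (only this rank's slot is refreshed) and
        #     launch an async all_gather whose result is consumed at the next step (handle.wait()
        #     in forward), overlapping communication with computation.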
        if self.patched_parallelism_size == 1:
            full_kv = kv
        else:
            if self.buffer_list is None:  # buffer not created
                full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
            elif self.counter <= self.warm_step:
                # logger.info(f"warmup: {self.counter}")
                dist.all_gather(
                    self.buffer_list,
                    kv,
                    group=self.process_group,
                    async_op=False,
                )
                full_kv = torch.cat(self.buffer_list, dim=1)
            else:
                # logger.info(f"use old kv to infer: {self.counter}")
                self.buffer_list[self.kv_buffer_idx].copy_(kv)
                full_kv = torch.cat(self.buffer_list, dim=1)
                assert self.handle is None, "we should maintain the kv of last step"
                self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)

        key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )  # NOTE(@lry89757) for torch >= 2.2, flash attn has already been integrated into scaled_dot_product_attention, https://pytorch.org/blog/pytorch2-2/
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

        b, l, c = hidden_states.shape
        kv_shape = (b, l, self.module.to_k.out_features * 2)
        if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
            self.buffer_list = [
                torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
                for _ in range(self.patched_parallelism_size)
            ]

            self.counter = 0
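
            # (Re)allocating the buffers (e.g. when the batch size or sequence length changes) also
            # restarts the warmup counter, so the next few steps all_gather fresh KV synchronously
            # before the stale-KV / async-gather fast path in _forward is used again.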

        attn_parameters = set(inspect.signature(self.module.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.module.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}

        output = self._forward(
            self.module,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        self.counter += 1

        return output


# Code adapted from: https://github.com/PipeFusion/PipeFusion/blob/main/pipefuser/modules/dit/patch_parallel/attn.py
class DistriSelfAttention(ParallelModule):
    def __init__(
        self,
        module: Attention,
        process_group: Union[ProcessGroup, List[ProcessGroup]],
        model_shard_infer_config: ModelShardInferenceConfig = None,
    ):
        super().__init__()
        self.counter = 0
        self.module = module
        self.buffer_list = None
        self.kv_buffer_idx = dist.get_rank(group=process_group)
        self.patched_parallelism_size = model_shard_infer_config.patched_parallelism_size
        self.handle = None
        self.process_group = process_group
        self.warm_step = 3  # for warmup

    @staticmethod
    def from_native_module(
        module: Attention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        model_shard_infer_config = kwargs.get("model_shard_infer_config", None)
        return DistriSelfAttention(
            module=module,
            process_group=process_group,
            model_shard_infer_config=model_shard_infer_config,
        )

    def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
        attn = self.module
        assert isinstance(attn, Attention)

        residual = hidden_states

        batch_size, sequence_length, _ = hidden_states.shape

        query = attn.to_q(hidden_states)

        encoder_hidden_states = hidden_states
        k = self.module.to_k(encoder_hidden_states)
        v = self.module.to_v(encoder_hidden_states)
        kv = torch.cat([k, v], dim=-1)  # shape of kv now: (bs, seq_len // parallel_size, dim * 2)

        if self.patched_parallelism_size == 1:
            full_kv = kv
        else:
            if self.buffer_list is None:  # buffer not created
                full_kv = torch.cat([kv for _ in range(self.patched_parallelism_size)], dim=1)
            elif self.counter <= self.warm_step:
                # logger.info(f"warmup: {self.counter}")
                dist.all_gather(
                    self.buffer_list,
                    kv,
                    group=self.process_group,
                    async_op=False,
                )
                full_kv = torch.cat(self.buffer_list, dim=1)
            else:
                # logger.info(f"use old kv to infer: {self.counter}")
                self.buffer_list[self.kv_buffer_idx].copy_(kv)
                full_kv = torch.cat(self.buffer_list, dim=1)
                assert self.handle is None, "we should maintain the kv of last step"
                self.handle = dist.all_gather(self.buffer_list, kv, group=self.process_group, async_op=True)

        if HAS_FLASH_ATTN:
            # flash attn
            key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            query = query.view(batch_size, -1, attn.heads, head_dim)
            key = key.view(batch_size, -1, attn.heads, head_dim)
            value = value.view(batch_size, -1, attn.heads, head_dim)

            hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
            hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
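            # flash_attn_func consumes tensors laid out as (batch, seq_len, heads, head_dim), so the
            # transpose(1, 2) used in the SDPA fallback below is not needed on this path.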
        else:
            # naive attn
            key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)

            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # finish the async KV all_gather from the previous step and (re)allocate the gather buffers if needed
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

        b, l, c = hidden_states.shape
        kv_shape = (b, l, self.module.to_k.out_features * 2)
        if self.patched_parallelism_size > 1 and (self.buffer_list is None or self.buffer_list[0].shape != kv_shape):
            self.buffer_list = [
                torch.empty(kv_shape, dtype=hidden_states.dtype, device=get_current_device())
                for _ in range(self.patched_parallelism_size)
            ]

            self.counter = 0

        output = self._forward(hidden_states, scale=scale)

        self.counter += 1
        return output