from typing import Optional, Tuple

import torch

from ..registry import meta_profiler_module


# TODO: This is hard to compute memory cost
@meta_profiler_module.register(torch.nn.MultiheadAttention)
def torch_nn_msa(self: torch.nn.MultiheadAttention,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 key_padding_mask: Optional[torch.Tensor] = None,
                 need_weights: bool = True,
                 attn_mask: Optional[torch.Tensor] = None,
                 average_attn_weights: bool = True) -> Tuple[int, int]:
    """Estimate the FLOPs and MACs of one forward pass of ``torch.nn.MultiheadAttention``.

    Only tensor shapes and module hyper-parameters are inspected, so the inputs
    never need to hold real data. Returns a ``(flops, macs)`` tuple.
    """
    if getattr(self, 'batch_first', False):
        batch_size = query.shape[0]
        len_idx = 1
    else:
        batch_size = query.shape[1]
        len_idx = 0
    dim_idx = 2

    qdim = query.shape[dim_idx]
    kdim = key.shape[dim_idx]
    vdim = value.shape[dim_idx]

    qlen = query.shape[len_idx]
    klen = key.shape[len_idx]
    vlen = value.shape[len_idx]

    num_heads = self.num_heads
    assert qdim == self.embed_dim

    if self.kdim is None:
        assert kdim == qdim
    if self.vdim is None:
        assert vdim == qdim

    flops = 0
    macs = 0

    # Q scaling
    flops += qlen * qdim

    # initial projections
    flops += 2 * ((qlen * qdim * qdim)    # QW
                  + (klen * kdim * kdim)    # KW
                  + (vlen * vdim * vdim)    # VW
                  )
    macs += ((qlen * qdim * qdim)    # QW
             + (klen * kdim * kdim)    # KW
             + (vlen * vdim * vdim)    # VW
             )

    if self.in_proj_bias is not None:
        flops += (qlen + klen + vlen) * qdim

    # attention heads: scale, matmul, softmax, matmul
    qk_head_dim = qdim // num_heads
    v_head_dim = vdim // num_heads

    head_flops = (
        2 * (qlen * klen * qk_head_dim)    # QK^T
        + (qlen * klen)    # softmax
        + 2 * (qlen * klen * v_head_dim)    # AV
    )
    # MACs count one multiply-accumulate per matmul element, i.e. half the matmul FLOPs
    head_macs = (
        (qlen * klen * qk_head_dim)    # QK^T
        + (qlen * klen * v_head_dim)    # AV
    )

    flops += num_heads * head_flops
    macs += num_heads * head_macs

    # final projection, bias is always enabled
    flops += qlen * vdim * (vdim + 1)
    macs += qlen * vdim * vdim

    flops *= batch_size
    macs *= batch_size

    return flops, macs
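

# Illustrative usage sketch (not part of the registry API; assumes the
# ``register`` decorator returns the function unchanged). Because the estimate
# only reads shapes and module attributes, it can be exercised directly:
#
#     mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
#     q = torch.randn(2, 10, 64)     # (batch, target_len, embed_dim)
#     kv = torch.randn(2, 12, 64)    # (batch, source_len, embed_dim)
#     flops, macs = torch_nn_msa(mha, q, kv, kv)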