#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math

import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out

from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from ._operation import AllGatherLast, SplitFirst
from ._utils import assert_tesseract_initialization, \
    get_tesseract_dim_dep_from_env
from .layers import Linear2p5D
from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
from .._common_utils import (ACT2FN, divide, to_2tuple,
                             set_tensor_parallel_attribute_by_partition)


@LAYERS.register_module
class ViTMLP2p5D(ParallelLayer):
    """MLP layer for 2.5D parallel Vision Transformer

    :param in_features: size of each input sample
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
    :type checkpoint: bool, optional
    :param weight_init: parameter initialization scheme, either 'torch' or 'jax', defaults to 'torch'
    :type weight_init: str, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 checkpoint: bool = False,
                 weight_init='torch'):
        super().__init__()

        assert_tesseract_initialization()
        self.in_features = in_features
        self.mlp_ratio = mlp_ratio
        self.checkpoint = checkpoint
        assert weight_init in ('torch', 'jax')

        if act_func == 'fused_gelu':
            # The fused kernel applies the bias together with GELU, so
            # dense_1 must return its bias separately instead of adding it.
            self.act = bias_gelu_impl
            skip_dense_1_add_bias = True
        else:
            self.act = ACT2FN[act_func]
            skip_dense_1_add_bias = False

        # Project to mlp_ratio * h.
        self.dense_1 = Linear2p5D(
            self.in_features,
            self.mlp_ratio * self.in_features,
            dtype=dtype,
            init_weight=weight_init,
            init_bias=weight_init,
            skip_bias_add=skip_dense_1_add_bias
        )

        # Project back to h.
        self.dense_2 = Linear2p5D(
            self.mlp_ratio * self.in_features,
            self.in_features,
            dtype=dtype,
            init_weight=weight_init,
            init_bias=weight_init
        )

        self.dropout = nn.Dropout(dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        if self.act == bias_gelu_impl:
            intermediate_output, bias = self.dense_1(hidden_states)
            intermediate_output = self.act(intermediate_output, bias)
        else:
            intermediate_output = self.dense_1(hidden_states)
            intermediate_output = self.act(intermediate_output)

        with seed(ParallelMode.TENSOR):
            intermediate_output = self.dropout(intermediate_output)
        output = self.dense_2(intermediate_output)

        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)
        return output

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        return checkpoint(self._forward, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
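

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal single-device version of the skip-bias-add pattern that
# ``ViTMLP2p5D`` uses when ``act_func='fused_gelu'``: ``dense_1`` skips its
# bias so the bias addition can be fused into the GELU. Plain
# ``nn.functional`` calls stand in for ``Linear2p5D``/``bias_gelu_impl``;
# the function name is hypothetical.
def _skip_bias_add_gelu_sketch(x: Tensor, layer: nn.Linear) -> Tensor:
    # Matrix multiply without the bias ...
    out = nn.functional.linear(x, layer.weight)
    # ... then add the bias as part of the (here unfused) GELU step.
    return nn.functional.gelu(out + layer.bias)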


@LAYERS.register_module
class ViTSelfAttention2p5D(ParallelLayer):
    """Self-attention layer for 2.5D parallel Vision Transformer

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layers
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
    :type checkpoint: bool, optional
    :param weight_init: parameter initialization scheme, either 'torch' or 'jax', defaults to 'torch'
    :type weight_init: str, optional
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 hidden_dropout_prob,
                 dtype=None,
                 checkpoint: bool = False,
                 weight_init='torch'):
        super().__init__()

        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.hidden_size = hidden_size
        # Attention heads are split across the tesseract dimension.
        self.num_attention_heads = divide(
            num_attention_heads, self.tesseract_dim)  # *
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.checkpoint = checkpoint
        assert weight_init in ('torch', 'jax')
        if weight_init == 'jax':
            self.init_bias = 'zero'
        else:
            self.init_bias = weight_init

        self.query_key_value = Linear2p5D(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
            init_weight=weight_init,
            init_bias=self.init_bias
        )
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2p5D(
            hidden_size,
            hidden_size,
            dtype=dtype,
            init_weight=weight_init,
            init_bias=self.init_bias
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        query_key_value = self.query_key_value(hidden_states)
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(
            query_key_value, 3, dim=-1)

        attention_scores = torch.matmul(
            query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)

        attention_probs = self.softmax(attention_scores)

        with seed(ParallelMode.TENSOR):
            attention_probs = self.attention_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(1, 2)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.dense(context_layer)
        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)

        return output

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        return checkpoint(self._forward, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
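

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal, non-parallel version of the head bookkeeping in
# ``ViTSelfAttention2p5D._forward``: the fused QKV projection is viewed as
# (B, heads, N, 3 * head_size), chunked into Q/K/V, and combined via scaled
# dot-product attention. In the parallel layer ``num_heads`` is the local
# share num_attention_heads // tesseract_dim; the function name and argument
# names are hypothetical.
def _attention_shape_sketch(qkv: Tensor, num_heads: int, head_size: int) -> Tensor:
    b, n, _ = qkv.shape  # (B, N, 3 * num_heads * head_size)
    qkv = qkv.view(b, n, num_heads, 3 * head_size).permute(0, 2, 1, 3)
    q, k, v = torch.chunk(qkv, 3, dim=-1)  # each (B, num_heads, N, head_size)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)
    context = torch.matmul(scores.softmax(dim=-1), v)
    # (B, num_heads, N, head_size) -> (B, N, num_heads * head_size)
    return context.transpose(1, 2).reshape(b, n, num_heads * head_size)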


@LAYERS.register_module
class ViTHead2p5D(ParallelLayer):
    """Output layer for 2.5D parallel Vision Transformer

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_classes: number of classes
    :type num_classes: int
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param weight_init: parameter initialization scheme, either 'torch' or 'jax', defaults to 'torch'
    :type weight_init: str, optional
    """

    def __init__(self,
                 hidden_size,
                 num_classes,
                 dtype=None,
                 weight_init='torch'):
        super().__init__()
        assert_tesseract_initialization()
        assert weight_init in ('torch', 'jax')
        if weight_init == 'jax':
            self.init_weight = 'zero'
            self.init_bias = 'zero'
        else:
            self.init_weight = weight_init
            self.init_bias = weight_init

        self.linear = Linear2p5D(
            hidden_size,
            num_classes,
            dtype=dtype,
            init_weight=self.init_weight,
            init_bias=self.init_bias
        )

    def forward(self, x: Tensor) -> Tensor:
        # Classify from the cls token only.
        x = x[:, 0]
        x = self.linear(x)
        return x


@LAYERS.register_module
class ViTPatchEmbedding2p5D(ParallelLayer):
    """2.5D Image to Patch Embedding

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param in_chans: number of channels of input image, defaults to 3
    :type in_chans: int, optional
    :param flatten: whether to flatten output tensor, defaults to True
    :type flatten: bool, optional
    :param weight_init: parameter initialization scheme, either 'torch' or 'jax', defaults to 'torch'
    :type weight_init: str, optional
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 in_chans=3,
                 flatten=True,
                 weight_init='torch'):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # Each rank holds only its slice of the embedding dimension.
        self.embed_dim = embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)  # *

        with seed(ParallelMode.TENSOR):
            self.proj = nn.Conv2d(in_chans,
                                  self.embed_dim,
                                  kernel_size=patch_size,
                                  stride=patch_size,
                                  device=get_current_device())

        self._set_tensor_parallel_attribute()

        if weight_init == 'jax':
            with seed(ParallelMode.TENSOR):
                fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
                std = math.sqrt(1.0 / fan_in)
                nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
                nn.init.zeros_(self.proj.bias)

    def _set_tensor_parallel_attribute(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
        set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)

    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x
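

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal single-device version of the patch-embedding step above: a Conv2d
# whose kernel size and stride both equal the patch size maps (B, C, H, W) to
# (B, D, H/P, W/P), which is then flattened into the (B, num_patches, D)
# sequence fed to the transformer. ``embed_dim`` here is the full embedding
# dimension; the parallel layer keeps only
# embed_dim // (tesseract_dep * tesseract_dim ** 2) channels per rank.
# The function name is hypothetical.
def _patch_embed_sketch(x: Tensor, patch_size: int, embed_dim: int) -> Tensor:
    proj = nn.Conv2d(x.shape[1], embed_dim,
                     kernel_size=patch_size, stride=patch_size)
    return proj(x).flatten(2).transpose(1, 2)  # BCHW -> BNC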


@LAYERS.register_module
class ViTInputSplitter2p5D(ParallelLayer):
    """Split the input tensor for 2.5D parallel Vision Transformer
    """

    def __init__(self):
        super().__init__()
        assert_tesseract_initialization()
        self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()

    def forward(self, x: Tensor) -> Tensor:
        x = AllGatherLast.apply(
            x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
        x = SplitFirst.apply(
            x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
        return x


@LAYERS.register_module
class ViTTokenFuser2p5D(ParallelLayer):
    """Fuse the cls token and position embedding into the input

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param drop_rate: dropout probability, defaults to 0.
    :type drop_rate: float, optional
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 drop_rate=0.):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.embed_dim = embed_dim

        self.cls_token = nn.Parameter(torch.zeros(
            (1, 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
            device=get_current_device()))
        self.pos_embed = nn.Parameter(torch.empty(
            (1, self.num_patches + 1,
             self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
            device=get_current_device()))
        with seed(ParallelMode.TENSOR):
            nn.init.trunc_normal_(self.pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
        set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
        set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)

    def _broadcast_params(self, param) -> None:
        """Broadcast to all column ranks for data consistency."""
        if self.tesseract_dep > 1:
            xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
            xz_group = gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)
            dist.broadcast(param, src=xz_rank[0], group=xz_group)

    def _sync_grad_hook(self, grad) -> Tensor:
        dist.all_reduce(grad, group=gpc.get_group(
            ParallelMode.PARALLEL_2P5D_XZ))
        grad = grad / self.tesseract_dim  # / self.tesseract_dep # *
        return grad

    def forward(self, x: Tensor) -> Tensor:
        # stole cls_tokens impl from Phil Wang, thanks
        cls_token = AllGatherLast.apply(
            self.cls_token, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
        cls_token = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        pos_embed = AllGatherLast.apply(
            self.pos_embed, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
        x = x + pos_embed
        with seed(ParallelMode.TENSOR):
            x = self.pos_drop(x)
        return x
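

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal single-device version of what ``ViTTokenFuser2p5D.forward``
# computes once the partitioned cls token and positional embedding have been
# gathered: prepend the class token to the patch sequence, add the positional
# embedding, and apply dropout. The function name and argument names are
# hypothetical.
def _token_fuser_sketch(x: Tensor, cls_token: Tensor, pos_embed: Tensor,
                        pos_drop: nn.Dropout) -> Tensor:
    cls_token = cls_token.expand(x.shape[0], -1, -1)  # (B, 1, D)
    x = torch.cat((cls_token, x), dim=1) + pos_embed  # (B, N + 1, D)
    return pos_drop(x)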