# This code is from NVIDIA Megatron,
# with minor changes.

import enum
import torch
import torch.nn as nn

from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader

# NOTE: These kernels are compiled for specific GPU architectures and are not widely applicable.
# try:
# from colossalai._C import scaled_masked_softmax as scaled_masked_softmax, scaled_upper_triangle_masked_softmax_cuda as scaled_upper_triang_masked_softmax
# except ImportError:
scaled_masked_softmax = None
scaled_upper_triang_masked_softmax = None


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3
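
# NOTE: mask types with an enum value greater than 1 (causal, paddedcausal) are
# treated as causal-style masks in the fused path below
# (see FusedScaleMaskSoftmax.forward_fused_softmax).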


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:

    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        # build and load kernel if not pre-built
        global scaled_upper_triang_masked_softmax
        if scaled_upper_triang_masked_softmax is None:
            scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:

    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale_t = torch.tensor([scale])

        # build and load kernel if not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: flag to indicate if the input is in fp16 data format.
        input_in_bf16: flag to indicate if the input is in bf16 data format.
        attn_mask_type: attention mask type (padding, causal, or paddedcausal).
        scaled_masked_softmax_fusion: flag to indicate whether the user wants to use softmax fusion.
        mask_func: mask function to be applied.
        softmax_in_fp32: if True, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."

        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type.value > 1:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type.value > 1:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            return ScaledMaskedSoftmax.apply(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    def get_batch_per_block(self, sq, sk, b, np):
        # build and load kernel if not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
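

# A minimal usage sketch (illustrative only): the boolean mask convention used in
# `mask_func` below (True marks positions to drop) is an assumption, not a
# documented contract of this module. With fusion disabled, the call exercises the
# pure-PyTorch fallback path, so it runs without the CUDA kernels.
if __name__ == "__main__":
    fused_softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,  # keep the demo on the torch fallback
        mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
        softmax_in_fp32=True,
        scale=None,
    )
    b, n_heads, sq, sk = 2, 4, 16, 16
    scores = torch.randn(b, n_heads, sq, sk)
    mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool)  # all False: nothing masked out
    probs = fused_softmax(scores, mask)
    print(probs.shape)  # torch.Size([2, 4, 16, 16])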