mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Refactor] Delete Duplicated code and refactor vec_copy utils and reduce utils (#5593)
* delete duplicated code and refactor vec_copy utils and reduce utils
* delete unused header file

pull/5531/head

parent a21912339a
commit d4cb023b62
@@ -1,11 +0,0 @@
-from .layer_norm import MixedFusedLayerNorm as LayerNorm
-from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
-
-__all__ = [
-    "LayerNorm",
-    "MultiHeadAttention",
-    "FusedScaleMaskSoftmax",
-    "ScaledUpperTriangMaskedSoftmax",
-    "AttnMaskType",
-]
@@ -4,6 +4,10 @@
#include "utils/vec_copy.h"
#include "../common/micros.h"

+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::get_vec_size;
+
+
template<typename scalar_t, bool Aligned, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
    const scalar_t* __restrict__ key,
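For readers unfamiliar with the vectorized-copy utilities referenced above, the pattern behind copy_vector<T, VecSize> can be sketched in standalone form roughly as follows; VecCopy and copy_kernel below are illustrative stand-ins, not the repository's implementation.

#include <cuda_fp16.h>

// Illustrative stand-in for copy_vector<T, VecSize>: move VecSize scalars with one
// wide load/store by reinterpreting them as a packed vector type.
template <typename T, int VecSize>
struct VecCopy;

// 4 half values (8 bytes) travel as a single float2 transaction.
template <>
struct VecCopy<half, 4> {
  __device__ static void apply(half* dst, const half* src) {
    *reinterpret_cast<float2*>(dst) = *reinterpret_cast<const float2*>(src);
  }
};

// Grid-stride copy; assumes n is a multiple of 4 and both pointers are 8-byte aligned.
__global__ void copy_kernel(half* dst, const half* src, int n) {
  for (int i = (blockIdx.x * blockDim.x + threadIdx.x) * 4; i < n;
       i += gridDim.x * blockDim.x * 4) {
    VecCopy<half, 4>::apply(dst + i, src + i);
  }
}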
@@ -4,6 +4,9 @@
#include "utils/vec_copy.h"
#include "../common/micros.h"

+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::get_vec_size;
+
template<typename scalar_t, bool Aligned, int VecSize>
__global__ void decode_kv_cache_memcpy_kernel(
    const scalar_t* __restrict__ key,
@@ -30,17 +30,25 @@ struct CastFunctor : public std::unary_function<From, To> {

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
                                     DEVICE)

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
                                     DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half, __float2half(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16,
                                     __float2bfloat16(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162,
                                     __float2bfloat162_rn(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
                                     DEVICE)

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float,
                                     __bfloat162float(val), DEVICE)

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
                                     DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
                                     DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, nv_bfloat162,
                                     __float2bfloat162_rn(val), DEVICE)

#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
}  // namespace funcs
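The cast functors above wrap CUDA's scalar and packed conversion intrinsics behind a uniform operator(). A rough standalone sketch of the same idea; LocalCastFunctor and cast_demo are hypothetical names, the intrinsics are the real CUDA ones.

#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Hypothetical local functor mirroring the CastFunctor specializations above.
template <typename From, typename To>
struct LocalCastFunctor;

// float -> half2: broadcast one float into both half lanes.
template <>
struct LocalCastFunctor<float, half2> {
  __device__ half2 operator()(float val) const { return __float2half2_rn(val); }
};

// __nv_bfloat16 -> float: widen so arithmetic can run in fp32.
template <>
struct LocalCastFunctor<__nv_bfloat16, float> {
  __device__ float operator()(__nv_bfloat16 val) const {
    return __bfloat162float(val);
  }
};

__global__ void cast_demo(const __nv_bfloat16* in, half2* out, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float x = LocalCastFunctor<__nv_bfloat16, float>()(in[i]) * scale;
    out[i] = LocalCastFunctor<float, half2>()(x);  // duplicate x into both lanes
  }
}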
@@ -8,7 +8,7 @@

namespace colossalAI {
namespace cuda {
-namespace utils {
+namespace funcs {

const float kReduceFloatInfNeg = -100000000.f;
const float kReduceFloatInfPos = 100000000.f;
@@ -88,93 +88,6 @@ __forceinline__ __device__ void block_reduce(T* pval) {
#undef COLOSSAL_WARP_REDUCE_IMPL
#undef COLOSSAL_BLOCK_REDUCE_IMPL

-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes(
-    T* x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = x[tid] + x[tid + i];
-    __syncthreads();
-  }
-
-  T final;
-
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = x[tid] + x[tid + 32];
-    else
-      final = val;
-    // __SYNCWARP();
-
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final = final + __shfl_down_sync(0xffffffff, final, i);
-  }
-
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes_max_op(
-    T* x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
-    __syncthreads();
-  }
-
-  T final;
-
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
-    else
-      final = val;
-    // __SYNCWARP();
-
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final =
-          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
-  }
-
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-}  // namespace utils
+}  // namespace funcs
}  // namespace cuda
}  // namespace colossalAI
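The helpers removed from this shared header are built on the usual two-stage CUDA block reduction: a warp-level shuffle reduction followed by a shared-memory pass over per-warp partials. A self-contained sketch of that pattern; warp_sum and sum_kernel are illustrative, not the library's API.

// Warp-level sum: after five shuffle steps lane 0 holds the sum of all 32 lanes.
__device__ __forceinline__ float warp_sum(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

// Block-level sum built from warp sums; assumes a one-dimensional block (<= 1024 threads).
__global__ void sum_kernel(const float* in, float* out, int n) {
  __shared__ float warp_partials[32];  // one slot per warp
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (tid < n) ? in[tid] : 0.f;

  v = warp_sum(v);  // intra-warp reduction
  if ((threadIdx.x & 31) == 0) warp_partials[threadIdx.x >> 5] = v;
  __syncthreads();

  if (threadIdx.x < 32) {  // first warp reduces the per-warp partials
    int num_warps = (blockDim.x + 31) >> 5;
    v = (threadIdx.x < num_warps) ? warp_partials[threadIdx.x] : 0.f;
    v = warp_sum(v);
    if (threadIdx.x == 0) atomicAdd(out, v);
  }
}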
@@ -15,7 +15,7 @@ namespace funcs {

// Note(LiuYang): As a retrieved table to check which operation is supported
// already
-enum class UnaryOpType { kLog2Ceil = 0 };
+enum class UnaryOpType { kLog2Ceil = 0, kAbs };

// Note(LiuYang): Implementation of common and simple unary operators should be
// placed here, otherwise, they should be placed in a new file under functors
@@ -31,6 +31,9 @@ struct UnaryOpFunctor;
    FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
  };

+COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(
+    T, T, UnaryOpType::kAbs, HOSTDEVICE, { return std::abs(val); }, typename T)
+
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil,
                                      HOSTDEVICE, {
                                        int log2_value = 0;
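The new kAbs entry gives UnaryOpFunctor an element-wise absolute-value specialization. A minimal standalone sketch of how such a functor is typically consumed; AbsFunctor and abs_kernel are illustrative stand-ins, and the repository version routes through std::abs rather than fabsf.

#include <math.h>

// Hypothetical element-wise absolute-value functor, analogous to
// UnaryOpFunctor<T, T, UnaryOpType::kAbs>.
template <typename T>
struct AbsFunctor;

template <>
struct AbsFunctor<float> {
  __host__ __device__ float operator()(float val) const { return fabsf(val); }
};

__global__ void abs_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = AbsFunctor<float>()(in[i]);  // out[i] = |in[i]|
}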
@@ -6,6 +6,9 @@
#include "../common/micros.h"
#include "../common/mp_type_traits.h"

+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::get_vec_size;
+
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void apply_emb_rotary_compute(
    scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
@@ -3,7 +3,10 @@

#include "utils/vec_copy.h"
#include "../common/micros.h"
#include "stdio.h"

+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::get_vec_size;
+

template <typename scalar_t, bool Aligned, int VecSize>
__device__ void apply_cos_and_sin_memcopy(
@@ -4,11 +4,11 @@

#include <cub/cub.cuh>

-#include "block_reduce.h"
+#include "funcs/reduce_function.h"

-using colossalAI::cuda::utils::block_reduce;
-using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::block_reduce;
+using colossalAI::cuda::funcs::ReduceType;

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
@@ -12,14 +12,98 @@

#include "multi_tensor_apply.cuh"
#include "../common/micros.h"
-#include "include/block_reduce.h"
+#include "funcs/reduce_function.h"

#define BLOCK_SIZE 512
#define ILP 4

-using colossalAI::cuda::utils::block_reduce;
-using colossalAI::cuda::utils::reduce_block_into_lanes;
-using colossalAI::cuda::utils::reduce_block_into_lanes_max_op;

+template <typename T>
+__device__ __forceinline__ T reduce_block_into_lanes(
+    T* x, T val, int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize =
+      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+#pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+    if (tid < i) x[tid] = x[tid] + x[tid + i];
+    __syncthreads();
+  }
+
+  T final;
+
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = x[tid] + x[tid + 32];
+    else
+      final = val;
+    // __SYNCWARP();
+
+#pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)
+      final = final + __shfl_down_sync(0xffffffff, final, i);
+  }
+
+  if (share_result) {
+    if (tid < lanes) x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
+
+template <typename T>
+__device__ __forceinline__ T reduce_block_into_lanes_max_op(
+    T* x, T val, int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize =
+      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+#pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
+    __syncthreads();
+  }
+
+  T final;
+
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
+    else
+      final = val;
+    // __SYNCWARP();
+
+#pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)
+      final =
+          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
+  }
+
+  if (share_result) {
+    if (tid < lanes) x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
+
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
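The reduce_block_into_lanes helper now lives next to its consumer. Its typical use is producing per-block partial sums, for example sums of squares in an L2 norm; below is a simplified standalone kernel with the reduction written inline, assuming a one-dimensional block whose size is a power of two and at most 512 (an illustration, not the repository's multi-tensor kernel).

// Per-block sum of squares, the building block of a multi-tensor L2 norm.
__global__ void l2norm_partial(const float* in, float* block_sums, int n) {
  __shared__ float smem[512];
  int tid = blockIdx.x * blockDim.x + threadIdx.x;

  // Grid-stride accumulation of the squares handled by this thread.
  float sq = 0.f;
  for (int i = tid; i < n; i += gridDim.x * blockDim.x) sq += in[i] * in[i];

  // Shared-memory tree reduction (reduce_block_into_lanes with lanes = 1 plays
  // this role in the real kernel).
  smem[threadIdx.x] = sq;
  __syncthreads();
  for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) smem[threadIdx.x] += smem[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0) block_sums[blockIdx.x] = smem[0];
}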
@@ -5,39 +5,20 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>

-#include "block_reduce.h"
#include "../common/micros.h"
#include "funcs/cast_functor.h"
#include "funcs/binary_functor.h"
+#include "funcs/reduce_function.h"
+#include "utils/vec_type_traits.h"

-using colossalAI::cuda::utils::block_reduce;
-using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::block_reduce;
+using colossalAI::cuda::funcs::ReduceType;
using colossalAI::cuda::funcs::CastFunctor;
using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;

-// Get type2 from type or vice versa (applied to half and bfloat16)
-template <typename T>
-struct TypeConverter {
-  using Type = half2;
-};
-
-#define TYPE_CONVERTER_SPECIALIZATION(FROM, TO) \
-  template <>                                   \
-  struct TypeConverter<FROM> {                  \
-    using Type = TO;                            \
-  };
-
-TYPE_CONVERTER_SPECIALIZATION(half2, at::Half)
-TYPE_CONVERTER_SPECIALIZATION(at::Half, half2)
-TYPE_CONVERTER_SPECIALIZATION(__nv_bfloat162, at::BFloat16)
-TYPE_CONVERTER_SPECIALIZATION(at::BFloat16, __nv_bfloat162)
-
-#undef TYPE_CONVERTER_SPECIALIZATION
+using colossalAI::cuda::utils::VecTypeTrait;

// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
@@ -48,7 +29,7 @@ __global__ void rms_layernorm_kernel(
    const float epsilon,
    const int num_tokens,
    const int hidden_size) {
-  using scalar2_t = typename TypeConverter<scalar_t>::Type;
+  using scalar2_t = typename VecTypeTrait<scalar_t, 2>::Type;
  BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
  __shared__ float s_variance;
@@ -134,7 +115,7 @@ __global__ void fused_add_rms_layernorm_kernel(
    const float epsilon,
    const int num_tokens,
    const int hidden_size) {
-  using scalar2_t = typename TypeConverter<scalar_t>::Type;
+  using scalar2_t = typename VecTypeTrait<scalar_t, 2>::Type;
  BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kAdd> add_scalar2t;
  BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
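Both kernels now derive their packed two-element type from VecTypeTrait<scalar_t, 2> instead of the local TypeConverter. The point of the packed type is that two half or bfloat16 values are processed per instruction; below is a standalone sketch with a hypothetical PairType trait standing in for VecTypeTrait, and scale_pairs as an illustrative kernel.

#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Hypothetical stand-in for VecTypeTrait<scalar_t, 2>::Type.
template <typename T> struct PairType;
template <> struct PairType<half>          { using Type = half2; };
template <> struct PairType<__nv_bfloat16> { using Type = __nv_bfloat162; };

// Scales pairs of half values in place; assumes the buffer holds 2 * n_pairs
// elements and is 4-byte aligned.
__global__ void scale_pairs(half* data, half scale, int n_pairs) {
  using half_pair = typename PairType<half>::Type;  // half2
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_pairs) {
    half_pair* p = reinterpret_cast<half_pair*>(data);
    p[i] = __hmul2(p[i], __half2half2(scale));  // multiply both lanes at once
  }
}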
@@ -16,13 +16,14 @@

#include "../common/micros.h"
#include "utils/vec_copy.h"
-#include "include/block_reduce.h"
+#include "funcs/reduce_function.h"
#include "funcs/unary_functor.h"

using colossalAI::cuda::funcs::UnaryOpFunctor;
using colossalAI::cuda::funcs::UnaryOpType;
-using colossalAI::cuda::utils::warp_reduce;
-using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::warp_reduce;
+using colossalAI::cuda::funcs::ReduceType;
+using colossalAI::cuda::utils::copy_vector;


/*
@@ -16,13 +16,15 @@

#include "../common/micros.h"
#include "utils/vec_copy.h"
-#include "include/block_reduce.h"
+#include "funcs/reduce_function.h"
#include "funcs/unary_functor.h"

using colossalAI::cuda::funcs::UnaryOpFunctor;
using colossalAI::cuda::funcs::UnaryOpType;
-using colossalAI::cuda::utils::warp_reduce;
-using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::warp_reduce;
+using colossalAI::cuda::funcs::ReduceType;
+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::copy_zero_vector;

/*
 * Extended softmax (from native aten pytorch) with following additional
@@ -1,12 +1,16 @@
#pragma once

#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>

+#include "../funcs/cast_functor.h"
#include "vec_type_traits.h"

+namespace colossalAI {
+namespace cuda {
+namespace utils {

template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
  using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
@@ -26,7 +30,8 @@ __device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
template <typename T, int VecSize>
__device__ __inline__ void copy_zero_vector(T *dst) {
  using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
-  *(reinterpret_cast<VT *>(dst)) = {0.0};
+  *(reinterpret_cast<VT *>(dst)) =
+      colossalAI::cuda::funcs::CastFunctor<float, VT>()(0.0f);
}

template <typename T>
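copy_zero_vector now obtains its zero value from the float-to-vector cast functor rather than brace initialization, so every lane of the packed type is set explicitly. A rough standalone sketch of the same pattern; make_zero and zero_vector are illustrative names, not the repository's API.

#include <cuda_fp16.h>

// Illustrative zero construction per packed type, mirroring what a
// float -> VT cast functor gives copy_zero_vector.
template <typename VecT>
__device__ __inline__ VecT make_zero();

template <>
__device__ __inline__ half2 make_zero<half2>() {
  return __float2half2_rn(0.f);  // both half lanes set to zero explicitly
}

template <>
__device__ __inline__ float4 make_zero<float4>() {
  return make_float4(0.f, 0.f, 0.f, 0.f);
}

// Zero VecSize scalars of type T behind dst with a single vector store.
template <typename T, typename VecT>
__device__ __inline__ void zero_vector(T* dst) {
  *reinterpret_cast<VecT*>(dst) = make_zero<VecT>();
}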
@@ -50,3 +55,7 @@ int get_vec_size(const torch::Tensor &tensor) {
      return 1;
  }
}
+
+}  // namespace utils
+}  // namespace cuda
+}  // namespace colossalAI
@@ -1,8 +1,9 @@
#pragma once

#include <c10/macros/Macros.h>
+#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/extension.h>

#include <cfloat>
@@ -20,12 +21,14 @@ struct VecTypeTrait {};
};

VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 2, float)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 4, float2)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 8, float4)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 2, float)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 4, float2)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 8, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 8, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2, half2)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float4)
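These specializations map a scalar type and a vector width to a CUDA built-in type whose size matches VecSize elements, so one vector load or store moves the whole group; the float/VecSize-8 entry is the exception and is handled by the dedicated copy_vector<float, 8> specialization shown in the vec_copy.h hunk above. Below is a small compile-time check of that size invariant, offered as a sketch rather than part of the diff.

#include <cuda_bf16.h>
#include <cuda_fp16.h>

// The mapped vector type must span exactly VecSize * sizeof(scalar) bytes for
// reinterpret_cast-based copies to stay in bounds.
static_assert(sizeof(half2) == 2 * sizeof(half), "Half x2 packs into half2");
static_assert(sizeof(float2) == 4 * sizeof(half), "Half x4 packs into float2");
static_assert(sizeof(float4) == 8 * sizeof(half), "Half x8 packs into float4");
static_assert(sizeof(__nv_bfloat162) == 2 * sizeof(__nv_bfloat16),
              "BFloat16 x2 packs into __nv_bfloat162");
static_assert(sizeof(float4) == 4 * sizeof(float), "float x4 packs into float4");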
@@ -1,190 +0,0 @@
# This code from NVIDIA Megatron:
# with minor changes.

import enum

import torch
import torch.nn as nn

from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader

try:
    from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
except ImportError:
    scaled_masked_softmax = None
    scaled_upper_triang_masked_softmax = None


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence

    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        global scaled_upper_triang_masked_softmax
        if scaled_upper_triang_masked_softmax:
            scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence

    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale_t = torch.tensor([scale])

        # build and load kernel if not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: Flag to indicate if input in fp16 data format.
        input_in_bf16: Flag to indicate if input in bf16 data format.
        attn_mask_type: Attention mask type (pad or causal)
        scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion
        mask_func: Mask function to be applied.
        softmax_in_fp32: If True, softmax in performed at fp32 precision.
        scale: Scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type.value > 1:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type.value > 1:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            return ScaledMaskedSoftmax.apply(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    def get_batch_per_block(self, sq, sk, b, np):
        # build and load kernel if not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()

        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)