mirror of https://github.com/hpcaitech/ColossalAI
add cast and op_functor for cuda build-in types (#5546)
parent
4bb5d8923a
commit
7ebdf48ac5
|
@ -0,0 +1,74 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "../utils/micros.h"
|
||||
|
||||
// Note(LiuYang): This file provides base math operation for data type
|
||||
// include POD and cuda built-in type such as half and __nv_bfloat16
|
||||
|
||||
namespace colossalAI {
|
||||
namespace cuda {
|
||||
namespace funcs {
|
||||
|
||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||
template <typename T>
|
||||
struct TypeConverter {
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2> {
|
||||
using Type = at::Half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<at::Half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162> {
|
||||
using Type = at::BFloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<at::BFloat16> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
|
||||
template <typename From, typename To>
|
||||
struct CastFunctor : public std::unary_function<From, To> {
|
||||
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
|
||||
};
|
||||
|
||||
#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT, \
|
||||
FUNCTION_MODIFIER) \
|
||||
template <> \
|
||||
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
|
||||
FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; } \
|
||||
};
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
|
||||
DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, nv_bfloat162,
|
||||
__float2bfloat162_rn(val), DEVICE)
|
||||
|
||||
#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
|
||||
} // namespace funcs
|
||||
} // namespace cuda
|
||||
} // namespace colossalAI
|
|
@ -1,31 +1,91 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "../utils/micros.h"
|
||||
|
||||
namespace colossalAI {
|
||||
namespace cuda {
|
||||
namespace funcs {
|
||||
|
||||
enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin };
|
||||
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
|
||||
|
||||
template <typename T, BinaryOpType Op>
|
||||
// Note(LiuYang): This file provides base math operation for data type
|
||||
// include POD and cuda built-in type such as half and __nv_bfloat16
|
||||
template <typename LT, typename RT, typename RET, BinaryOpType Op>
|
||||
struct BinaryOpFunctor;
|
||||
|
||||
template <typename T>
|
||||
struct BinaryOpFunctor<T, BinaryOpType::kAdd>
|
||||
: public std::binary_function<T, T, T> {
|
||||
__host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
|
||||
};
|
||||
#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
|
||||
FUNCTION_MODIFIER, ARGS...) \
|
||||
template <ARGS> \
|
||||
struct BinaryOpFunctor<T, T, T, BINARY_OP_TYPE> \
|
||||
: public std::binary_function<T, T, T> { \
|
||||
FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BinaryOpFunctor<T, BinaryOpType::kMax>
|
||||
: public std::binary_function<T, T, T> {
|
||||
__host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); }
|
||||
};
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs,
|
||||
HOSTDEVICE, typename T)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs,
|
||||
HOSTDEVICE, typename T)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs* rhs,
|
||||
HOSTDEVICE, typename T)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs,
|
||||
HOSTDEVICE, typename T)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs),
|
||||
HOSTDEVICE, typename T)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs),
|
||||
HOSTDEVICE, typename T)
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd,
|
||||
__hadd(lhs, rhs), DEVICE)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd,
|
||||
__hadd2(lhs, rhs), DEVICE)
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
|
||||
__hadd(lhs, rhs), DEVICE)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd,
|
||||
__hadd2(lhs, rhs), DEVICE)
|
||||
#else
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
|
||||
__float2bfloat16(__bfloat162float(lhs) +
|
||||
__bfloat162float(rhs)),
|
||||
DEVICE)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat162, BinaryOpType::kAdd,
|
||||
__floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs),
|
||||
__high2float(lhs) + __high2float(rhs)),
|
||||
DEVICE)
|
||||
#endif
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul,
|
||||
__hmul(lhs, rhs), DEVICE)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul,
|
||||
__hmul2(lhs, rhs), DEVICE)
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
|
||||
__hmul(lhs, rhs), DEVICE)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul,
|
||||
__hmul2(lhs, rhs), DEVICE)
|
||||
#else
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
|
||||
__float2bfloat16(__bfloat162float(lhs) *
|
||||
__bfloat162float(rhs)),
|
||||
DEVICE)
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat162, BinaryOpType::kMul,
|
||||
__floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs),
|
||||
__high2float(lhs) * __high2float(rhs)),
|
||||
DEVICE)
|
||||
#endif
|
||||
|
||||
#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION
|
||||
|
||||
} // namespace funcs
|
||||
} // namespace cuda
|
||||
|
|
|
@ -22,12 +22,12 @@ struct GetOpForReduceType;
|
|||
|
||||
template <typename T>
|
||||
struct GetOpForReduceType<T, ReduceType::kMax> {
|
||||
using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kMax>;
|
||||
using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kMax>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GetOpForReduceType<T, ReduceType::kSum> {
|
||||
using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kAdd>;
|
||||
using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kAdd>;
|
||||
};
|
||||
|
||||
#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
|
||||
|
|
|
@ -10,10 +10,15 @@
|
|||
|
||||
#include "block_reduce.h"
|
||||
#include "../common/micros.h"
|
||||
#include "utils/cuda_type_utils.h"
|
||||
#include "funcs/cast_functor.h"
|
||||
#include "funcs/op_functor.h"
|
||||
|
||||
using colossalAI::cuda::utils::block_reduce;
|
||||
using colossalAI::cuda::utils::ReduceType;
|
||||
using colossalAI::cuda::funcs::TypeConverter;
|
||||
using colossalAI::cuda::funcs::CastFunctor;
|
||||
using colossalAI::cuda::funcs::BinaryOpFunctor;
|
||||
using colossalAI::cuda::funcs::BinaryOpType;
|
||||
|
||||
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
|
||||
if (DATA_SIZE == 2) { \
|
||||
|
@ -53,6 +58,7 @@ __global__ void rms_layernorm_kernel(
|
|||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using scalar2_t = typename TypeConverter<scalar_t>::Type;
|
||||
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
|
||||
__shared__ float s_variance;
|
||||
|
||||
/*
|
||||
|
@ -72,12 +78,13 @@ __global__ void rms_layernorm_kernel(
|
|||
float variance = 0.0f;
|
||||
int row_offset = blockIdx.x * hidden_size / 2;
|
||||
|
||||
|
||||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
x_local[cnt] = input_ptr[id];
|
||||
float v1 = cuda_cast<float>(x_local[cnt].x);
|
||||
float v2 = cuda_cast<float>(x_local[cnt].y);
|
||||
float v1 = CastFunctor<scalar_t,float>()(x_local[cnt].x);
|
||||
float v2 = CastFunctor<scalar_t,float>()(x_local[cnt].y);
|
||||
variance += v1 * v1 + v2 * v2;
|
||||
}
|
||||
block_reduce<float, ReduceType::kSum,1>(&variance);
|
||||
|
@ -86,11 +93,11 @@ __global__ void rms_layernorm_kernel(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
|
||||
scalar2_t s_variance_2 = CastFunctor<float,scalar2_t>()(s_variance);
|
||||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
|
||||
out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -137,6 +144,9 @@ __global__ void fused_add_rms_layernorm_kernel(
|
|||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using scalar2_t = typename TypeConverter<scalar_t>::Type;
|
||||
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kAdd> add_scalar2t;
|
||||
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
|
||||
|
||||
__shared__ float s_variance;
|
||||
scalar2_t x_local[4];
|
||||
|
||||
|
@ -151,9 +161,9 @@ __global__ void fused_add_rms_layernorm_kernel(
|
|||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
x_local[cnt] = input_ptr[id];
|
||||
x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
|
||||
float v1 = cuda_cast<float>(x_local[cnt].x);
|
||||
float v2 = cuda_cast<float>(x_local[cnt].y);
|
||||
x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]);
|
||||
float v1 = CastFunctor<scalar_t,float>()(x_local[cnt].x);
|
||||
float v2 = CastFunctor<scalar_t,float>()(x_local[cnt].y);
|
||||
variance += v1 * v1 + v2 * v2;
|
||||
residual_ptr[id] = x_local[cnt];
|
||||
}
|
||||
|
@ -163,11 +173,12 @@ __global__ void fused_add_rms_layernorm_kernel(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
|
||||
scalar2_t s_variance_2 = CastFunctor<float, scalar2_t>()(s_variance);
|
||||
|
||||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
|
||||
input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,122 +0,0 @@
|
|||
/*
|
||||
* This code from NVIDIA FasterTransformer:
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T add(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ half2 add(half2 a, half2 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ half add(half a, half b) {
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
return bf16hadd2(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
return bf16hadd(a, b);
|
||||
}
|
||||
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T mul(T a, T b, T c) {
|
||||
return a * b * c;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ half2 mul(half2 a, half2 b, half2 c) {
|
||||
return __hmul2(__hmul2(a, b), c);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b,
|
||||
__nv_bfloat16 c) {
|
||||
return bf16hmul(a, b, c);
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b,
|
||||
__nv_bfloat162 c) {
|
||||
return bf16hmul2(a, b, c);
|
||||
}
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__device__ inline T_OUT cuda_cast(T_IN val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
|
||||
return make_float2(val.x, val.y);
|
||||
}
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, float>(float val) {
|
||||
return make_float2(val, val);
|
||||
}
|
||||
template <>
|
||||
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
|
||||
return __half22float2(val);
|
||||
}
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
|
||||
return __float22half2_rn(val);
|
||||
}
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, float>(float val) {
|
||||
return __float2half2_rn(val);
|
||||
}
|
||||
template <>
|
||||
__device__ inline half2 cuda_cast<half2, half>(half val) {
|
||||
return __half2half2(val);
|
||||
}
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, half>(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||
template <typename T>
|
||||
struct TypeConverter {
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2> {
|
||||
using Type = at::Half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<at::Half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162> {
|
||||
using Type = at::BFloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<at::BFloat16> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
#endif // ENABLE_BF16
|
|
@ -12,3 +12,7 @@
|
|||
throw std::runtime_error(cudaGetErrorString(status)); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define HOST __host__
|
||||
#define DEVICE __device__
|
||||
#define HOSTDEVICE __host__ __device__
|
||||
|
|
Loading…
Reference in New Issue