refactor code

pull/5435/head
xs_courtesy 2024-03-08 15:41:14 +08:00
parent 01d289d8e5
commit 5eb5ff1464
16 changed files with 117 additions and 118 deletions

View File

@@ -9,7 +9,15 @@
 #include <ATen/ATen.h>
-#include "compat.h"
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
 #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
   switch (TYPE) {                                 \
@@ -214,90 +222,3 @@
       AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
                "'");                                                             \
   }
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes(
-    T *x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = x[tid] + x[tid + i];
-    __syncthreads();
-  }
-  T final;
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = x[tid] + x[tid + 32];
-    else
-      final = val;
-    // __SYNCWARP();
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final = final + __shfl_down_sync(0xffffffff, final, i);
-  }
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-  return final;
-}
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes_max_op(
-    T *x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
-    __syncthreads();
-  }
-  T final;
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
-    else
-      final = val;
-    // __SYNCWARP();
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final =
-          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
-  }
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-  return final;
-}
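For reference, a minimal call-site sketch of how a dispatch macro in this style is typically consumed. The helper names (double_kernel, double_inplace) and the convention that each case binds scalar_t before running the variadic statements are assumptions, not code from this commit.

#include <torch/extension.h>
// The header defining DISPATCH_HALF_AND_BFLOAT must also be included;
// its path depends on where the calling file lives.

// Hypothetical kernel, for illustration only.
template <typename scalar_t>
__global__ void double_kernel(scalar_t* data, int64_t numel) {
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < numel) data[i] = data[i] + data[i];
}

void double_inplace(at::Tensor t) {
  const int threads = 256;
  const int blocks = static_cast<int>((t.numel() + threads - 1) / threads);
  // Assumed convention: each case of the macro does `using scalar_t = ...;`
  // (at::Half or at::BFloat16) and then executes the statements passed as
  // __VA_ARGS__, erroring out on any other dtype.
  DISPATCH_HALF_AND_BFLOAT(t.scalar_type(), "double_inplace",
                           double_kernel<scalar_t><<<blocks, threads>>>(
                               t.data_ptr<scalar_t>(), t.numel()););
}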

View File

@@ -2,10 +2,10 @@
 #include <ATen/ATen.h>
-#include "../type_shim.h"
+#include "micros.h"
-namespace infer {
-namespace dtype {
+namespace colossalAI {
+namespace common {
 template <typename T>
 class MPTypeTrait {
@@ -31,5 +31,5 @@ class MPTypeTrait<at::BFloat16> {
   using Type = float;
 };
-} // namespace dtype
-} // namespace infer
+} // namespace common
+} // namespace colossalAI
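MPTypeTrait maps a reduced-precision storage type to the type the math should run in (float for at::Half and at::BFloat16). A minimal sketch of the usage pattern, mirroring the silu_kernel call sites below; the helper itself is illustrative only, and the primary template is assumed to leave float as float.

#include <ATen/ATen.h>

// Illustrative device helper: accumulate in the math type, then cast back
// to the storage type to avoid precision loss in the intermediate sum.
template <typename T>
__device__ __forceinline__ T add_in_math_type(T a, T b) {
  using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
  return static_cast<T>(static_cast<MT>(a) + static_cast<MT>(b));
}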

View File

@@ -2,13 +2,13 @@
 #include <torch/extension.h>
 #include <stdio.h>
-#include "type_shim.h"
-#include "include/mp_type_traits.h"
+#include "../common/micros.h"
+#include "../common/mp_type_traits.h"
 template<typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
   // x * sigmoid(x)
-  using MT = typename infer::dtype::MPTypeTrait<T>::Type;
+  using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
   return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
 }
@@ -17,7 +17,7 @@ __global__ void act_and_mul_kernel(
     const scalar_t* __restrict__ ins_data,
     scalar_t* __restrict__ outs_data,
     const int64_t numel) {
-  using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;
+  using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
   int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   const int64_t grid_size = blockDim.x * gridDim.x;
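The body of act_and_mul_kernel lies outside this hunk. Below is a sketch of the grid-stride pattern that idx and grid_size feed, under the assumption that the first numel inputs are the gate values and the next numel are the up-projection; the committed indexing may differ.

// Illustrative stand-alone kernel (an assumption, not the committed
// act_and_mul_kernel): output[i] = SiLU(gate[i]) * up[i], computed in the
// math type MT and stored back in scalar_t.
template <typename scalar_t>
__global__ void silu_and_mul_sketch(const scalar_t* __restrict__ ins_data,
                                    scalar_t* __restrict__ outs_data,
                                    const int64_t numel) {
  using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
  int64_t idx = static_cast<int64_t>(threadIdx.x) +
                static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  const int64_t grid_size = blockDim.x * gridDim.x;
  for (int64_t i = idx; i < numel; i += grid_size) {
    MT gate = static_cast<MT>(silu_kernel(ins_data[i]));
    MT up = static_cast<MT>(ins_data[numel + i]);
    outs_data[i] = static_cast<scalar_t>(gate * up);
  }
}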

View File

@@ -1,10 +0,0 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
-#ifndef TORCH_CHECK
-#define TORCH_CHECK AT_CHECK
-#endif
-#ifdef VERSION_GE_1_3
-#define DATA_PTR data_ptr
-#else
-#define DATA_PTR data
-#endif

View File

@@ -2,7 +2,7 @@
 #include <torch/extension.h>
 #include <stdio.h>
-#include "type_shim.h"
+#include "../common/micros.h"
 template<typename scalar_t>
 __global__ void decode_kv_cache_memcpy_kernel(

View File

@@ -310,3 +310,90 @@ __inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
   }
   warpReduce<ReduceType::kMax, num>(pval);
 }
+template <typename T>
+__device__ __forceinline__ T reduce_block_into_lanes(
+    T *x, T val, int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize =
+      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+#pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+    if (tid < i) x[tid] = x[tid] + x[tid + i];
+    __syncthreads();
+  }
+  T final;
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = x[tid] + x[tid + 32];
+    else
+      final = val;
+    // __SYNCWARP();
+#pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)
+      final = final + __shfl_down_sync(0xffffffff, final, i);
+  }
+  if (share_result) {
+    if (tid < lanes) x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+  return final;
+}
+template <typename T>
+__device__ __forceinline__ T reduce_block_into_lanes_max_op(
+    T *x, T val, int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize =
+      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+#pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
+    __syncthreads();
+  }
+  T final;
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
+    else
+      final = val;
+    // __SYNCWARP();
+#pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)
+      final =
+          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
+  }
+  if (share_result) {
+    if (tid < lanes) x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+  return final;
+}
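A minimal sketch of a typical reduce_block_into_lanes call site; the kernel, shared-buffer size, and launch shape below are assumptions rather than code from this commit.

// Assumed example: per-block sum of an array, launched with a 1-D block
// whose size is a multiple of 32, e.g. block_sum<<<num_blocks, 256>>>(in, out, n).
__global__ void block_sum(const float* __restrict__ in,
                          float* __restrict__ out, int n) {
  __shared__ float s_vals[512];  // must cover blockDim.x * blockDim.y threads
  float partial = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x)
    partial += in[i];
  // lanes defaults to 1, so the full block sum is valid in thread 0 only.
  float total = reduce_block_into_lanes(s_vals, partial);
  if (threadIdx.x == 0 && threadIdx.y == 0) out[blockIdx.x] = total;
}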

View File

@@ -7,7 +7,7 @@
 #include <cassert>
 #include <vector>
-#include "compat.h"
+#include "../common/micros.h"
 namespace {

View File

@@ -9,7 +9,7 @@
 #include "ATen/AccumulateType.h"
 #include "ATen/cuda/CUDAContext.h"
 #include "ATen/cuda/DeviceUtils.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
 template <typename U>
 __device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
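cuWelfordOnlineSum follows Welford's online algorithm for a running mean and (unnormalized) variance. The standard form of the update is sketched below; the committed body is outside this hunk and may differ in details.

// Standard Welford online update, written as a separate illustrative helper.
template <typename U>
__device__ void welford_update_sketch(const U curr, U& mu, U& sigma2, U& count) {
  count = count + U(1);
  U delta = curr - mu;
  mu = mu + delta / count;     // updated running mean
  U delta2 = curr - mu;        // uses the updated mean
  sigma2 = sigma2 + delta * delta2;
}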

View File

@@ -15,7 +15,7 @@
 #include <assert.h>
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
 #define BLOCK_SIZE 512
 #define ILP 4

View File

@@ -12,7 +12,7 @@
 #include <assert.h>
 #include <c10/cuda/CUDAGuard.h>
-#include "compat.h"
+#include "../common/micros.h"
 // #include <iostream>

View File

@@ -11,7 +11,8 @@
 #include <assert.h>
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
+#include "include/block_reduce.h"
 #define BLOCK_SIZE 512
 #define ILP 4

View File

@@ -10,7 +10,7 @@
 #include <assert.h>
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
 #define BLOCK_SIZE 512
 #define ILP 4

View File

@@ -10,7 +10,7 @@
 #include <sstream>
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
 #define BLOCK_SIZE 512
 #define ILP 4

View File

@@ -7,7 +7,7 @@
 #include <assert.h>
 #include <cuda_runtime.h>
-#include "compat.h"
+#include "../common/micros.h"
 #include "multi_tensor_apply.cuh"
 #define BLOCK_SIZE 512

View File

@@ -10,7 +10,7 @@
 #include <torch/extension.h>
 #include "scaled_masked_softmax.h"
-#include "type_shim.h"
+#include "../common/micros.h"
 namespace multihead_attn {
 namespace fused_softmax {

View File

@@ -10,7 +10,7 @@
 #include <torch/extension.h>
 #include "scaled_upper_triang_masked_softmax.h"
-#include "type_shim.h"
+#include "../common/micros.h"
 namespace multihead_attn {
 namespace fused_softmax {