Browse Source

refactor code

pull/5435/head
xs_courtesy 9 months ago
parent
commit
5eb5ff1464
  1. 97
      extensions/csrc/common/micros.h
  2. 10
      extensions/csrc/common/mp_type_traits.h
  3. 8
      extensions/csrc/cuda/activation_kernel.cu
  4. 10
      extensions/csrc/cuda/compat.h
  5. 2
      extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
  6. 87
      extensions/csrc/cuda/include/block_reduce.h
  7. 2
      extensions/csrc/cuda/layer_norm_cuda.cpp
  8. 2
      extensions/csrc/cuda/layer_norm_cuda_kernel.cu
  9. 2
      extensions/csrc/cuda/multi_tensor_adam.cu
  10. 2
      extensions/csrc/cuda/multi_tensor_apply.cuh
  11. 3
      extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
  12. 2
      extensions/csrc/cuda/multi_tensor_lamb.cu
  13. 2
      extensions/csrc/cuda/multi_tensor_scale_kernel.cu
  14. 2
      extensions/csrc/cuda/multi_tensor_sgd_kernel.cu
  15. 2
      extensions/csrc/cuda/scaled_masked_softmax_cuda.cu
  16. 2
      extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu

97
extensions/csrc/cuda/type_shim.h → extensions/csrc/common/micros.h

@ -9,7 +9,15 @@
#include <ATen/ATen.h>
#include "compat.h"
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
@ -214,90 +222,3 @@
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
"'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final =
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}

10
extensions/csrc/cuda/include/mp_type_traits.h → extensions/csrc/common/mp_type_traits.h

@ -2,10 +2,10 @@
#include <ATen/ATen.h>
#include "../type_shim.h"
#include "micros.h"
namespace infer {
namespace dtype {
namespace colossalAI {
namespace common {
template <typename T>
class MPTypeTrait {
@ -31,5 +31,5 @@ class MPTypeTrait<at::BFloat16> {
using Type = float;
};
} // namespace dtype
} // namespace infer
} // namespace common
} // namespace colossalAI

8
extensions/csrc/cuda/activation_kernel.cu

@ -2,13 +2,13 @@
#include <torch/extension.h>
#include <stdio.h>
#include "type_shim.h"
#include "include/mp_type_traits.h"
#include "../common/micros.h"
#include "../common/mp_type_traits.h"
template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
using MT = typename infer::dtype::MPTypeTrait<T>::Type;
using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
}
@ -17,7 +17,7 @@ __global__ void act_and_mul_kernel(
const scalar_t* __restrict__ ins_data,
scalar_t* __restrict__ outs_data,
const int64_t numel) {
using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;
using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
const int64_t grid_size = blockDim.x * gridDim.x;

10
extensions/csrc/cuda/compat.h

@ -1,10 +0,0 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif

2
extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu

@ -2,7 +2,7 @@
#include <torch/extension.h>
#include <stdio.h>
#include "type_shim.h"
#include "../common/micros.h"
template<typename scalar_t>
__global__ void decode_kv_cache_memcpy_kernel(

87
extensions/csrc/cuda/include/block_reduce.h

@ -310,3 +310,90 @@ __inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
}
warpReduce<ReduceType::kMax, num>(pval);
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final =
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}

2
extensions/csrc/cuda/layer_norm_cuda.cpp

@ -7,7 +7,7 @@
#include <cassert>
#include <vector>
#include "compat.h"
#include "../common/micros.h"
namespace {

2
extensions/csrc/cuda/layer_norm_cuda_kernel.cu

@ -9,7 +9,7 @@
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include "type_shim.h"
#include "../common/micros.h"
template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {

2
extensions/csrc/cuda/multi_tensor_adam.cu

@ -15,7 +15,7 @@
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#include "../common/micros.h"
#define BLOCK_SIZE 512
#define ILP 4

2
extensions/csrc/cuda/multi_tensor_apply.cuh

@ -12,7 +12,7 @@
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"
#include "../common/micros.h"
// #include <iostream>

3
extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu

@ -11,7 +11,8 @@
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#include "../common/micros.h"
#include "include/block_reduce.h"
#define BLOCK_SIZE 512
#define ILP 4

2
extensions/csrc/cuda/multi_tensor_lamb.cu

@ -10,7 +10,7 @@
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#include "../common/micros.h"
#define BLOCK_SIZE 512
#define ILP 4

2
extensions/csrc/cuda/multi_tensor_scale_kernel.cu

@ -10,7 +10,7 @@
#include <sstream>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#include "../common/micros.h"
#define BLOCK_SIZE 512
#define ILP 4

2
extensions/csrc/cuda/multi_tensor_sgd_kernel.cu

@ -7,7 +7,7 @@
#include <assert.h>
#include <cuda_runtime.h>
#include "compat.h"
#include "../common/micros.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512

2
extensions/csrc/cuda/scaled_masked_softmax_cuda.cu

@ -10,7 +10,7 @@
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
#include "../common/micros.h"
namespace multihead_attn {
namespace fused_softmax {

2
extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu

@ -10,7 +10,7 @@
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
#include "../common/micros.h"
namespace multihead_attn {
namespace fused_softmax {

Loading…
Cancel
Save