Refactored the tail-processing code and the macro-definition logic. (#5519)

pull/5546/head
yuehuayingxueluo 2024-03-28 10:42:51 +08:00 committed by GitHub
parent e6496dd371
commit 934e31afb2
5 changed files with 129 additions and 165 deletions

View File

@@ -2,7 +2,7 @@ ROOT=$(realpath $(dirname $0))
 echo $ROOT
 PY_SCRIPT=${ROOT}/benchmark_llama.py
 GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
-mode="colossalai"
+mode=$1
 mkdir -p logs

View File

@@ -56,21 +56,14 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }

-#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
-                                                           TYPE, NAME, ...) \
-  switch (HIGH_PRECISION) { \
-    case false: { \
-      const bool high_precision = false; \
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
-      break; \
-    } \
-    case true: { \
-      const bool high_precision = true; \
-      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
-      break; \
-    } \
-    default: \
-      AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
+                                                           TYPE, NAME, ...) \
+  if (HIGH_PRECISION) { \
+    const bool high_precision = true; \
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
+  } else { \
+    const bool high_precision = false; \
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
   }

 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
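Note that the new if/else form also removes a latent oddity in the old macro: switch (HIGH_PRECISION) on a bool has only two reachable cases, so the default: AT_ERROR(...) branch could never fire. Both forms bridge a runtime flag to a compile-time constant by expanding the same dispatch body in two branches, each with high_precision bound to a literal. A minimal, self-contained C++ sketch of that technique (the names run and DISPATCH_HIGH_PRECISION are illustrative, not part of this patch):

#include <cstdio>

// Templated worker: inside, high_precision is a compile-time constant,
// usable as a template argument or in if constexpr.
template <bool high_precision>
void run(float x) {
    std::printf("high_precision=%d x=%f\n", high_precision ? 1 : 0,
                static_cast<double>(x));
}

// Bridge a runtime bool to a compile-time constant: each branch binds
// high_precision to a literal, then expands the same dispatch body.
#define DISPATCH_HIGH_PRECISION(FLAG, ...) \
    if (FLAG) {                            \
        const bool high_precision = true;  \
        __VA_ARGS__;                       \
    } else {                               \
        const bool high_precision = false; \
        __VA_ARGS__;                       \
    }

int main() {
    bool flag = true;  // value known only at run time
    DISPATCH_HIGH_PRECISION(flag, run<high_precision>(1.0f));
    return 0;
}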

View File

@@ -27,17 +27,11 @@ struct MPTypeTrait<at::BFloat16> {
   using Type = float;
 };

-template <bool high_precision, typename scalar_t>
-struct ScalarTypeTrait;
-
-template <typename T>
-struct ScalarTypeTrait<true, T> {
-  using Type = typename MPTypeTrait<T>::Type;
-};
-
-template <typename T>
-struct ScalarTypeTrait<false, T> {
-  using Type = T;
+template <bool high_precision, typename T>
+struct ScalarTypeTrait {
+  using Type =
+      typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
+                                T>::type;
 };

 }  // namespace common
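std::conditional collapses the forward declaration and the two partial specializations into one primary template. A self-contained sketch showing the equivalence (the half tag type and the MPTypeTrait stand-in below are illustrative, not the library's own definitions):

#include <type_traits>

struct half {};  // stand-in for at::Half / at::BFloat16

// Maps a storage type to the type used for high-precision math.
template <typename T>
struct MPTypeTrait { using Type = T; };

template <>
struct MPTypeTrait<half> { using Type = float; };

// One primary template replaces ScalarTypeTrait<true, T> and
// ScalarTypeTrait<false, T>: std::conditional picks the branch at
// compile time.
template <bool high_precision, typename T>
struct ScalarTypeTrait {
    using Type = typename std::conditional<high_precision,
                                           typename MPTypeTrait<T>::Type,
                                           T>::type;
};

static_assert(std::is_same<ScalarTypeTrait<true, half>::Type, float>::value,
              "high precision promotes half to float");
static_assert(std::is_same<ScalarTypeTrait<false, half>::Type, half>::value,
              "low precision keeps the storage type");

int main() { return 0; }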

View File

@@ -4,7 +4,7 @@
 #include "utils/vector_copy_utils.h"
 #include "../common/micros.h"

-template<typename scalar_t, int VecSize>
+template<typename scalar_t, bool Aligned, int VecSize>
 __global__ void context_kv_cache_memcpy_kernel(
     const scalar_t* __restrict__ key,
     const scalar_t* __restrict__ value,
@@ -55,17 +55,19 @@ __global__ void context_kv_cache_memcpy_kernel(
     }

     // tail process
-    for (; i < hidden_size; ++i ) {
-        head_id = i / head_dim;
-        head_offset = i % head_dim;
-        key_src_id = total_token_id * key_stride + i;
-        value_src_id = total_token_id * value_stride + i;
-        target_id = block_id * hidden_size * block_size
-                    + head_id * block_size * head_dim
-                    + block_offset * head_dim + head_offset;
-        key_cache[target_id] = key[key_src_id];
-        value_cache[target_id] = value[value_src_id];
+    if (!Aligned) {
+        for (; i < hidden_size; ++i ) {
+            head_id = i / head_dim;
+            head_offset = i % head_dim;
+            key_src_id = total_token_id * key_stride + i;
+            value_src_id = total_token_id * value_stride + i;
+            target_id = block_id * hidden_size * block_size
+                        + head_id * block_size * head_dim
+                        + block_offset * head_dim + head_offset;
+            key_cache[target_id] = key[key_src_id];
+            value_cache[target_id] = value[value_src_id];
+        }
     }
 }
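Guarding the tail loop with if (!Aligned) lets the compiler drop the scalar remainder path entirely from kernels instantiated with Aligned == true, rather than carrying a loop that can never execute. A CPU-side C++ sketch of the same main-loop-plus-guarded-tail pattern (chunked_copy is illustrative; the real kernel walks the KV-cache block table and uses copy_vector for the vectorized step):

#include <cstddef>

// Copies n elements, VecSize at a time, then optionally a scalar tail.
// With Aligned == true the tail is compiled out, matching the patched
// kernels, so that instantiation must only be used when n % VecSize == 0.
template <bool Aligned, int VecSize, typename T>
void chunked_copy(T* dst, const T* src, std::size_t n) {
    std::size_t i = 0;
    for (; i + VecSize <= n; i += VecSize) {
        for (int j = 0; j < VecSize; ++j) {  // stands in for copy_vector<T, VecSize>
            dst[i + j] = src[i + j];
        }
    }
    if (!Aligned) {  // tail process: emitted only for the unaligned instantiation
        for (; i < n; ++i) {
            dst[i] = src[i];
        }
    }
}

int main() {
    float src[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    float dst[10] = {};
    chunked_copy<false, 4>(dst, src, 10);  // 8 vectorized elements + 2 in the tail
    chunked_copy<true, 4>(dst, src, 8);    // evenly divisible: no tail needed
    return 0;
}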
@@ -93,76 +95,61 @@ void apply_context_kv_cache_memcpy(
     int vec_size = get_vec_size<scalar_t>(key);

+    bool aligned = true;
+
     if (head_dim % vec_size != 0) {
-        // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
-        vec_size = 1;
+        aligned = false;
     }

     int thread_nums = head_num * head_dim / vec_size;
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

     dim3 grid(max_seq_len_in_batch, batch_size);
     dim3 block(std::min(thread_nums, 512));

-    switch (vec_size) {
-        case 1:
-            context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        case 2:
-            context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        case 4:
-            context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                cu_seqlens.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                batch_size,
-                block_table_stride,
-                key_stride,
-                value_stride
-            );
-            break;
-        default:
-            AT_ERROR("Unsupported vectorized size ", vec_size);
-            break;
-    }
+#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
+    do { \
+        context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+            key.data_ptr<scalar_t>(), \
+            value.data_ptr<scalar_t>(), \
+            key_cache.data_ptr<scalar_t>(), \
+            value_cache.data_ptr<scalar_t>(), \
+            sequence_lengths.data_ptr<int>(), \
+            cu_seqlens.data_ptr<int>(), \
+            block_tables.data_ptr<int>(), \
+            head_num, \
+            head_dim, \
+            block_size, \
+            batch_size, \
+            block_table_stride, \
+            key_stride, \
+            value_stride \
+        ); \
+    } while(0)
+
+#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \
+    do { \
+        switch (vec_size) { \
+            case 1: \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \
+                break; \
+            case 2: \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \
+                break; \
+            case 4: \
+                CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \
+                break; \
+            default: \
+                AT_ERROR("Unsupported vectorized size ", vec_size); \
+                break; \
+        } \
+    } while(0)
+
+    if (aligned) {
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
+    }
+    else {
+        CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
+    }

     AT_CUDA_CHECK(cudaGetLastError());
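Each launch macro is wrapped in do { ... } while(0) so that an expansion behaves as a single statement and composes safely with if/else. A compilable toy version of the two-level dispatch (printf stands in for the kernel launch; the double-underscore parameter names mirror the patch, although such identifiers are formally reserved in C++):

#include <cstdio>

// Innermost level: stands in for the templated kernel launch.
#define KERNEL_LAUNCH(__aligned, __vec_size)            \
    do {                                                \
        std::printf("launch aligned=%d vec=%d\n",       \
                    (__aligned) ? 1 : 0, (__vec_size)); \
    } while (0)

// Middle level: turns the runtime vec_size into a compile-time case,
// one branch per supported width, as in the patch.
#define KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned)                           \
    do {                                                                 \
        switch (vec_size) {                                              \
            case 1: KERNEL_LAUNCH(__aligned, 1); break;                  \
            case 2: KERNEL_LAUNCH(__aligned, 2); break;                  \
            case 4: KERNEL_LAUNCH(__aligned, 4); break;                  \
            default: std::printf("unsupported vec_size %d\n", vec_size); \
        }                                                                \
    } while (0)

int main() {
    int vec_size = 4;      // e.g. the result of get_vec_size<scalar_t>(key)
    bool aligned = false;  // e.g. head_dim % vec_size != 0

    // Because each expansion is one statement, even an unbraced
    // if/else parses correctly; that is the point of do/while(0).
    if (aligned)
        KERNEL_LAUNCH_VEC_SIZE_CASE(true);
    else
        KERNEL_LAUNCH_VEC_SIZE_CASE(false);
    return 0;
}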

View File

@@ -4,7 +4,7 @@
 #include "utils/vector_copy_utils.h"
 #include "../common/micros.h"

-template<typename scalar_t, int VecSize>
+template<typename scalar_t, bool Aligned, int VecSize>
 __global__ void decode_kv_cache_memcpy_kernel(
     const scalar_t* __restrict__ key,
     const scalar_t* __restrict__ value,
@@ -45,17 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
         copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
     }

-    for (; i < hidden_size; ++i ) {
-        const int head_id = i / head_dim;
-        const int head_offset = i % head_dim;
-        const int64_t key_src_id = seq_id * key_stride + i;
-        const int64_t value_src_id = seq_id * value_stride + i;
-        const int64_t target_id = block_id * hidden_size * block_size
-                                  + head_id * block_size * head_dim
-                                  + block_offset * head_dim + head_offset;
-        key_cache[target_id] = key[key_src_id];
-        value_cache[target_id] = value[value_src_id];
+    if (!Aligned) {
+        for (; i < hidden_size; ++i ) {
+            const int head_id = i / head_dim;
+            const int head_offset = i % head_dim;
+            const int64_t key_src_id = seq_id * key_stride + i;
+            const int64_t value_src_id = seq_id * value_stride + i;
+            const int64_t target_id = block_id * hidden_size * block_size
+                                      + head_id * block_size * head_dim
+                                      + block_offset * head_dim + head_offset;
+            key_cache[target_id] = key[key_src_id];
+            value_cache[target_id] = value[value_src_id];
+        }
     }
 }
@@ -80,70 +82,58 @@ void apply_decode_kv_cache_memcpy(
     int vec_size = get_vec_size<scalar_t>(key);

+    bool aligned = true;
+
     if (head_dim % vec_size != 0) {
-        // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
-        vec_size = 1;
+        aligned = false;
     }

     int thread_nums = head_num * head_dim / vec_size;
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

     dim3 grid(num_tokens);
     dim3 block(std::min(thread_nums, 512));

-    switch (vec_size) {
-        case 1:
-            decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        case 2:
-            decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        case 4:
-            decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(),
-                value.data_ptr<scalar_t>(),
-                key_cache.data_ptr<scalar_t>(),
-                value_cache.data_ptr<scalar_t>(),
-                sequence_lengths.data_ptr<int>(),
-                block_tables.data_ptr<int>(),
-                head_num,
-                head_dim,
-                block_size,
-                key_stride,
-                value_stride,
-                block_table_stride
-            );
-            break;
-        default:
-            AT_ERROR("Unsupported vectorized size ", vec_size);
-            break;
-    }
+#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
+    do { \
+        decode_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+            key.data_ptr<scalar_t>(), \
+            value.data_ptr<scalar_t>(), \
+            key_cache.data_ptr<scalar_t>(), \
+            value_cache.data_ptr<scalar_t>(), \
+            sequence_lengths.data_ptr<int>(), \
+            block_tables.data_ptr<int>(), \
+            head_num, \
+            head_dim, \
+            block_size, \
+            key_stride, \
+            value_stride, \
+            block_table_stride \
+        ); \
+    } while(0)
+
+#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned, __vec_size) \
+    do { \
+        switch (__vec_size) { \
+            case 1: \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \
+                break; \
+            case 2: \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \
+                break; \
+            case 4: \
+                DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \
+                break; \
+            default: \
+                AT_ERROR("Unsupported vectorized size ", __vec_size); \
+                break; \
+        } \
+    } while(0)
+
+    if (aligned) {
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true, vec_size);
+    }
+    else {
+        DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false, vec_size);
+    }

     AT_CUDA_CHECK(cudaGetLastError());