mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Delete duplicated copy_vector (#5716)
parent 7806842f2d
commit 121d7ad629
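Summary of the change shown in the hunks below: the deprecated copy_vector helper (including its float/8 specialization) and copy_zero_vector are deleted from utils/vec_copy.h, and the remaining call sites in the affected CUDA kernels switch to the newer copy and copy_zero helpers. Note that the argument order flips with the rename: the removed copy_vector took (dst, src), while copy takes (src, dst).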
@@ -5,7 +5,6 @@
 #include "funcs/cast_functor.h"
 #include "common/micros.h"

-using colossalAI::cuda::utils::copy_vector;
 using colossalAI::cuda::utils::get_vec_size;
 using colossalAI::cuda::utils::copy;
 using colossalAI::funcs::CastFunctor;
@@ -8,7 +8,6 @@
 #include "funcs/cast_functor.h"
 #include "funcs/binary_functor.h"

-using colossalAI::cuda::utils::copy_vector;
 using colossalAI::cuda::utils::get_vec_size;
 using colossalAI::cuda::utils::copy;
 using colossalAI::funcs::CastFunctor;
@@ -4,7 +4,7 @@
 #include "utils/vec_copy.h"
 #include "common/micros.h"

-using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::copy;
 using colossalAI::cuda::utils::get_vec_size;


@@ -23,8 +23,8 @@ __device__ void apply_cos_and_sin_memcopy(
   int begin_id = threadIdx.x * VecSize;

   for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){
-    copy_vector<scalar_t, VecSize>(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id);
-    copy_vector<scalar_t, VecSize>(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id);
+    copy<scalar_t, VecSize>(cos_cache_ptr + src_offset_id + begin_id, cos + dest_offset_id + begin_id);
+    copy<scalar_t, VecSize>(sin_cache_ptr + src_offset_id + begin_id, sin + dest_offset_id + begin_id);
   }

   if (!Aligned) {
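The hunk above also makes the new calling convention visible: copy_vector took (dst, src), while copy takes (src, dst). Below is a minimal, self-contained CUDA sketch of that convention, not the repository's actual vec_copy.h code; demo_copy, the 16-byte/VecSize == 4 restriction, and the float4 mapping are illustrative assumptions.

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for a vectorized copy with the new (src, dst) argument order.
// Assumes sizeof(T) * VecSize == 16 bytes and 16-byte aligned pointers.
template <typename T, int VecSize>
__device__ __forceinline__ void demo_copy(const T *src, T *dst) {
  static_assert(sizeof(T) * VecSize == sizeof(float4),
                "sketch only handles 16-byte vectors");
  *reinterpret_cast<float4 *>(dst) = *reinterpret_cast<const float4 *>(src);
}

// Each thread copies one VecSize-wide chunk, loosely mirroring the
// cos/sin cache memcopy loop in the hunk above (not the exact loop).
__global__ void demo_memcopy(const float *cache, float *out, int head_dim) {
  constexpr int VecSize = 4;
  for (int i = threadIdx.x * VecSize; i <= head_dim - VecSize;
       i += blockDim.x * VecSize) {
    demo_copy<float, VecSize>(cache + i, out + i);  // source first, dest second
  }
}

int main() {
  const int head_dim = 128;
  float *cache = nullptr, *out = nullptr;
  cudaMallocManaged(&cache, head_dim * sizeof(float));
  cudaMallocManaged(&out, head_dim * sizeof(float));
  for (int i = 0; i < head_dim; ++i) cache[i] = static_cast<float>(i);
  demo_memcopy<<<1, 32>>>(cache, out, head_dim);
  cudaDeviceSynchronize();
  printf("out[5] = %f\n", out[5]);  // should print 5.000000
  cudaFree(cache);
  cudaFree(out);
  return 0;
}

Built with nvcc, this copies 128 floats in 16-byte chunks, the same vector-width idea the kernels rely on.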
@@ -23,7 +23,7 @@ using colossalAI::funcs::UnaryOpFunctor;
 using colossalAI::funcs::UnaryOpType;
 using colossalAI::funcs::warp_reduce;
 using colossalAI::funcs::ReduceType;
-using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::copy;


 /*
@@ -87,8 +87,8 @@ __global__ void scaled_masked_softmax_warp_forward(

       if (element_index < batch_element_count) {
         int itr_idx = i * element_count + it * WARP_SIZE;
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
-        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(src + itr_idx, temp_data);
+        copy<uint8_t, ELEMENTS_PER_LDG_STG>(mask + itr_idx, temp_mask);

 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -144,8 +144,8 @@ __global__ void scaled_masked_softmax_warp_forward(
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
           out[element] = elements[i][it + element] / sum[i];
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            dst + i * element_count + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, dst + i * element_count + it * WARP_SIZE);
       } else {
         break;
       }
@@ -200,10 +200,10 @@ __global__ void scaled_masked_softmax_warp_backward(
     for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
       int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
       if (element_index < batch_element_count) {
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_grad, grad + i * element_count + it * WARP_SIZE);
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_output, output + i * element_count + it * WARP_SIZE);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            grad + i * element_count + it * WARP_SIZE, temp_grad);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            output + i * element_count + it * WARP_SIZE, temp_output);

 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -245,8 +245,8 @@ __global__ void scaled_masked_softmax_warp_backward(
               (output_t)(scale * (grad_reg[i][it + element] -
                                   output_reg[i][it + element] * sum[i]));
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            gradInput + i * element_count + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, gradInput + i * element_count + it * WARP_SIZE);
       }
     }
   }
@@ -23,8 +23,8 @@ using colossalAI::funcs::UnaryOpFunctor;
 using colossalAI::funcs::UnaryOpType;
 using colossalAI::funcs::warp_reduce;
 using colossalAI::funcs::ReduceType;
-using colossalAI::cuda::utils::copy_vector;
-using colossalAI::cuda::utils::copy_zero_vector;
+using colossalAI::cuda::utils::copy;
+using colossalAI::cuda::utils::copy_zero;

 /*
  * Extended softmax (from native aten pytorch) with following additional
@@ -75,8 +75,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
       int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

       if (element_index < batch_element_count) {
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_data, src + i * element_count * stride + it * WARP_SIZE);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            src + i * element_count * stride + it * WARP_SIZE, temp_data);

 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -140,10 +140,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
            out[element] = 0;
          }
        }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            dst + i * element_count * stride + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, dst + i * element_count * stride + it * WARP_SIZE);
       } else if (element_index < element_count) {
-        copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
+        copy_zero<output_t, ELEMENTS_PER_LDG_STG>(
             dst + i * element_count * stride + it * WARP_SIZE);
       } else {
         break;
@@ -199,10 +199,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
       int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
       if (element_index < batch_element_count) {
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_output, output + i * element_count * stride + it * WARP_SIZE);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            grad + i * element_count * stride + it * WARP_SIZE, temp_grad);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            output + i * element_count * stride + it * WARP_SIZE, temp_output);

 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -248,8 +248,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
               (output_t)(scale * (grad_reg[i][it + element] -
                                   output_reg[i][it + element] * sum[i]));
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            gradInput + i * element_count * stride + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, gradInput + i * element_count * stride + it * WARP_SIZE);
       }
     }
   }
@@ -8,25 +8,8 @@ namespace colossalAI {
 namespace cuda {
 namespace utils {

-// Note(LiuYang): Depreciated
 template <typename T, int VecSize>
-__device__ __inline__ void copy_vector(T *dst, const T *src) {
-  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
-  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
-}
-
-template <>
-__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
-  // Since the maximum memory alignment length is 128 bits, we choose float4
-  // here.
-  *(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
-  *(reinterpret_cast<float4 *>(dst + 4)) =
-      *(reinterpret_cast<const float4 *>(src + 4));
-}
-
-// Note(LiuYang): Depreciated
-template <typename T, int VecSize>
-__device__ __inline__ void copy_zero_vector(T *dst) {
+__device__ __inline__ void copy_zero(T *dst) {
   using VT = typename common::VecTypeTrait<T, VecSize>::Type;
   *(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
 }
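For reference, here is a hedged sketch of the helpers the kernels above rely on after this commit. The surviving copy and copy_zero live in utils/vec_copy.h and use common::VecTypeTrait and funcs::CastFunctor, which are not shown in this diff; DemoVecTrait and the float4 zero literal below are simplified stand-ins so the snippet compiles on its own, and the body of copy is an assumption modeled on the removed copy_vector.

#include <cuda_runtime.h>

// Simplified stand-in for common::VecTypeTrait<T, VecSize>::Type
// (only the float / VecSize == 4 case is provided here).
template <typename T, int VecSize>
struct DemoVecTrait;

template <>
struct DemoVecTrait<float, 4> {
  using Type = float4;
};

// copy(src, dst): vectorized load/store, source first, destination second.
// Body assumed from the removed copy_vector; not copied from the real header.
template <typename T, int VecSize>
__device__ __inline__ void copy(const T *src, T *dst) {
  using VT = typename DemoVecTrait<T, VecSize>::Type;
  *reinterpret_cast<VT *>(dst) = *reinterpret_cast<const VT *>(src);
}

// copy_zero(dst): store a zero-filled vector. The real header builds the zero
// with funcs::CastFunctor<float, VT>()(0.0f); a float4 literal is used here.
template <typename T, int VecSize>
__device__ __inline__ void copy_zero(T *dst) {
  using VT = typename DemoVecTrait<T, VecSize>::Type;
  *reinterpret_cast<VT *>(dst) = VT{0.0f, 0.0f, 0.0f, 0.0f};
}

// Toy kernel: copy the first half of dst from src and zero-fill the rest,
// 4 floats per thread. Assumes n is a multiple of 8 and 16-byte alignment.
__global__ void demo_fill(const float *src, float *dst, int n) {
  int i = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  if (i >= n) return;
  if (i < n / 2) {
    copy<float, 4>(src + i, dst + i);
  } else {
    copy_zero<float, 4>(dst + i);
  }
}

This mirrors the pattern in the softmax kernels above: vectorized copies for in-bounds chunks, and copy_zero for positions that must be written but carry no data.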