#include "block_reduce.h"
#include "cuda_util.h"
#include "kernels.h"
#include "ls_cub.cuh"

ls::cub::CachingDeviceAllocator g_allocator(true);

template <typename T>
__global__ void ls_cross_entropy_fw_kernel(
    const T *__restrict__ inputs, const int *__restrict__ targets,
    float *__restrict__ outputs, float *__restrict__ nll_loss_outputs,
    const int padding_idx, const float epsilon, const int vocab_size) {
  /* step 1: each thread accumulates its partial max logit in max_input and
   * its partial sums of shifted logits and their exps in sum_logits */
  const int block_start = blockIdx.x * vocab_size;
  const int left_idx = block_start + threadIdx.x;
  const int right_idx = (blockIdx.x + 1) * vocab_size;
  float max_input[1] = {REDUCE_FLOAT_INF_NEG};
  float sum_logits[2] = {0.f, 0.f};  // sum of logits, sum of exp(logit)
  int target_tid = targets[blockIdx.x];

  // padded positions contribute no loss
  if (target_tid == padding_idx) {
    if (threadIdx.x == 0) {
      nll_loss_outputs[blockIdx.x] = 0.f;
      outputs[blockIdx.x] = 0.f;
    }
    return;
  }

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
  }
  blockReduce<ReduceType::kMax, 1>(max_input);
  __shared__ float s_max_input;
  if (threadIdx.x == 0) {
    s_max_input = max_input[0];
  }
  __syncthreads();

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    float logit = static_cast<float>(inputs[i]) - s_max_input;
    sum_logits[0] += logit;
    sum_logits[1] += expf(logit);
  }

  blockReduce<ReduceType::kSum, 2>(sum_logits);
  __shared__ float s_sum_logit;
  __shared__ float s_sum_exp;
  if (threadIdx.x == 0) {
    s_sum_logit = sum_logits[0];
    s_sum_exp = sum_logits[1];
  }
  __syncthreads();

  /* step 2: thread 0 combines the reductions into the label-smoothed loss */
  float eps_i = epsilon / (vocab_size - 1);
  if (threadIdx.x == 0) {
    // neg_log_prob of the target: log(sum(exp(x - x_max))) - (x_target - x_max)
    float nll_loss = logf(s_sum_exp) -
                     static_cast<float>(inputs[block_start + target_tid]) +
                     s_max_input;
    nll_loss_outputs[blockIdx.x] = nll_loss;
    // neg_log_prob summed over the whole vocab, used by the smoothing term
    float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit;
    outputs[blockIdx.x] =
        (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
  }
}
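
/* Illustrative only: a minimal host-side float reference for the forward
 * loss above, assuming the same layout (one row of vocab_size logits per
 * token, `rows` = batch_size * seq_len). The function name and signature are
 * hypothetical and not part of the library; this is a sketch for checking
 * the kernel's math, not a production path. */
static inline void reference_cross_entropy_fw(const float *inputs,
                                              const int *targets,
                                              float *outputs,
                                              float *nll_loss_outputs,
                                              int padding_idx, float epsilon,
                                              int vocab_size, int rows) {
  float eps_i = epsilon / (vocab_size - 1);
  for (int r = 0; r < rows; ++r) {
    const float *row = inputs + r * vocab_size;
    int tgt = targets[r];
    if (tgt == padding_idx) {
      outputs[r] = 0.f;
      nll_loss_outputs[r] = 0.f;
      continue;
    }
    // max-shifted log-sum-exp, mirroring the kernel's two block reductions
    float x_max = row[0];
    for (int j = 1; j < vocab_size; ++j) x_max = fmaxf(x_max, row[j]);
    float sum_logit = 0.f, sum_exp = 0.f;
    for (int j = 0; j < vocab_size; ++j) {
      sum_logit += row[j] - x_max;
      sum_exp += expf(row[j] - x_max);
    }
    float nll_loss = logf(sum_exp) - (row[tgt] - x_max);
    float sum_nll_loss = vocab_size * logf(sum_exp) - sum_logit;
    nll_loss_outputs[r] = nll_loss;
    outputs[r] = (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
  }
}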

template <typename T>
__global__ void ls_cross_entropy_bw_kernel(
    const float *__restrict__ grad_outputs, const T *__restrict__ inputs,
    const int *__restrict__ targets, T *__restrict__ grad_inputs,
    const int padding_idx, const float epsilon, const int vocab_size) {
  /* step 1: each thread accumulates its partial max logit in max_input and
   * its partial sum of exp(logit) in sum_logits */
  const int block_start = blockIdx.x * vocab_size;
  const int left_idx = block_start + threadIdx.x;
  const int right_idx = (blockIdx.x + 1) * vocab_size;
  float max_input[1] = {REDUCE_FLOAT_INF_NEG};
  float sum_logits[1] = {0.f};
  const float grad_out = static_cast<float>(grad_outputs[0]);
  int target_tid = targets[blockIdx.x];

  // padded positions receive zero gradient
  if (target_tid == padding_idx) {
    for (int i = left_idx; i < right_idx; i += blockDim.x) {
      grad_inputs[i] = 0.f;
    }
    return;
  }

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
  }
  blockReduce<ReduceType::kMax, 1>(max_input);
  __shared__ float s_max_input;
  if (threadIdx.x == 0) {
    s_max_input = max_input[0];
  }
  __syncthreads();

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    float logit = static_cast<float>(inputs[i]) - s_max_input;
    sum_logits[0] += expf(logit);
  }

  blockReduce<ReduceType::kSum, 1>(sum_logits);
  __shared__ float s_sum_exp;
  if (threadIdx.x == 0) {
    s_sum_exp = sum_logits[0];
  }
  __syncthreads();

  /* step 2: each thread writes the gradient for its slice of the vocab */
  float eps_i = epsilon / (vocab_size - 1);
  float nll_weight = 1.f - epsilon - eps_i;

  for (int i = left_idx; i < right_idx; i += blockDim.x) {
    float prob = expf(static_cast<float>(inputs[i]) - s_max_input) / s_sum_exp;
    float grad = 0.f;
    grad += (vocab_size * prob - 1) * eps_i;
    grad += prob * nll_weight;
    if ((i - block_start) == target_tid) {
      grad -= nll_weight;
    }
    grad_inputs[i] = grad_out * grad;
  }
}
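
/* A sketch of where the gradient above comes from (not from the original
 * source; standard softmax cross-entropy algebra). With p_j = softmax(x)_j,
 * t the target index, V = vocab_size, and w = 1 - epsilon - eps_i, the
 * forward loss is
 *   L = w * (-log p_t) + eps_i * sum_j (-log p_j).
 * Using d(-log p_t)/dx_j = p_j - 1[j == t] and
 *   d(sum_k -log p_k)/dx_j = V * p_j - 1,
 * we get
 *   dL/dx_j = w * (p_j - 1[j == t]) + eps_i * (V * p_j - 1),
 * which is exactly the three `grad` terms accumulated per element above,
 * scaled by the incoming scalar grad_out. */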

template <typename T>
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
                             float *outputs_ptr, float *nll_loss_ptr,
                             float *loss_buffer, const int padding_idx,
                             const float epsilon, const int batch_size,
                             const int seq_len, const int vocab_size,
                             cudaStream_t stream) {
  int grid_dim = batch_size * seq_len;
  float *nll_loss_buffer = loss_buffer + grid_dim;
  ls_cross_entropy_fw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
      inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
      epsilon, vocab_size);

  int num_items = grid_dim;
  void *d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  // First call with d_temp_storage == NULL only queries the required
  // temp-storage size; the reductions themselves run in the two calls below.
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
                                             loss_buffer, outputs_ptr,
                                             num_items, stream));
  CHECK_GPU_ERROR(
      g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
                                             loss_buffer, outputs_ptr,
                                             num_items, stream));
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
                                             nll_loss_buffer, nll_loss_ptr,
                                             num_items, stream));
  CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
}
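
/* A hypothetical usage sketch (not from the original source), kept as
 * comments so it does not compile into the library. Note that loss_buffer
 * must hold 2 * batch_size * seq_len floats, since its second half serves
 * as nll_loss_buffer above; outputs_ptr and nll_loss_ptr each receive a
 * single reduced scalar. */
// float *d_logits, *d_loss_buf, *d_loss, *d_nll_loss;  // device pointers
// int *d_targets;
// const int batch_size = 8, seq_len = 128, vocab_size = 32000;
// CHECK_GPU_ERROR(cudaMalloc(&d_loss_buf,
//                            2 * batch_size * seq_len * sizeof(float)));
// launch_cross_entropy_fw<float>(d_logits, d_targets, d_loss, d_nll_loss,
//                                d_loss_buf, /*padding_idx=*/0,
//                                /*epsilon=*/0.1f, batch_size, seq_len,
//                                vocab_size, /*stream=*/0);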

template void launch_cross_entropy_fw<float>(
    const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
    float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template void launch_cross_entropy_fw<__half>(
    const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
    float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template <typename T>
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
                             const int *targets_ptr, T *grad_inputs_ptr,
                             const int padding_idx, const float epsilon,
                             const int batch_size, const int seq_len,
                             const int vocab_size, cudaStream_t stream) {
  int grid_dim = batch_size * seq_len;
  ls_cross_entropy_bw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
      grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
      epsilon, vocab_size);
}

template void launch_cross_entropy_bw<float>(
    const float *grad_outputs_ptr, const float *inputs_ptr,
    const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template void launch_cross_entropy_bw<__half>(
    const float *grad_outputs_ptr, const __half *inputs_ptr,
    const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);