#include "block_reduce.h" #include "cuda_util.h" #include "kernels.h" #include "ls_cub.cuh" ls::cub::CachingDeviceAllocator g_allocator(true); template __global__ void ls_cross_entropy_fw_kernel( const T *__restrict__ inputs, const int *__restrict__ targets, float *__restrict__ outputs, float *__restrict__ nll_loss_outputs, const int padding_idx, const float epsilon, const int vocab_size) { /* step1: compute each thread's max_logit and sum_exp_logit, store in * max_input, sum_exp_logit */ const int block_start = blockIdx.x * vocab_size; const int left_idx = block_start + threadIdx.x; const int right_idx = (blockIdx.x + 1) * vocab_size; float max_input[1] = {REDUCE_FLOAT_INF_NEG}; float sum_logits[2] = {0.f, 0.f}; // logit and logit exp int target_tid = targets[blockIdx.x]; if (target_tid == padding_idx) { if (threadIdx.x == 0) { nll_loss_outputs[blockIdx.x] = 0.f; outputs[blockIdx.x] = 0.f; } return; } for (int i = left_idx; i < right_idx; i += blockDim.x) { max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); } blockReduce(max_input); __shared__ float s_max_input; if (threadIdx.x == 0) { s_max_input = max_input[0]; } __syncthreads(); for (int i = left_idx; i < right_idx; i += blockDim.x) { float logit = static_cast(inputs[i]) - s_max_input; sum_logits[0] += logit; sum_logits[1] += expf(logit); } blockReduce(sum_logits); __shared__ float s_sum_logit; __shared__ float s_sum_exp; if (threadIdx.x == 0) { s_sum_logit = sum_logits[0]; s_sum_exp = sum_logits[1]; } __syncthreads(); float eps_i = epsilon / (vocab_size - 1); if (threadIdx.x == 0) { // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max) float nll_loss = logf(s_sum_exp) - static_cast(inputs[block_start + target_tid]) + s_max_input; nll_loss_outputs[blockIdx.x] = nll_loss; float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit; outputs[blockIdx.x] = (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss; } } template __global__ void ls_cross_entropy_bw_kernel( const float *__restrict__ grad_outputs, const T *__restrict__ inputs, const int *__restrict__ targets, T *__restrict__ grad_inputs, const int padding_idx, const float epsilon, const int vocab_size) { /* step1: compute each thread's max_logit and sum_exp_logit, store in * max_input, sum_exp_logit */ const int block_start = blockIdx.x * vocab_size; const int left_idx = block_start + threadIdx.x; const int right_idx = (blockIdx.x + 1) * vocab_size; float max_input[1] = {REDUCE_FLOAT_INF_NEG}; float sum_logits[1] = {0.f}; const float grad_out = static_cast(grad_outputs[0]); int target_tid = targets[blockIdx.x]; if (target_tid == padding_idx) { for (int i = left_idx; i < right_idx; i += blockDim.x) { grad_inputs[i] = 0.f; } return; } for (int i = left_idx; i < right_idx; i += blockDim.x) { max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); } blockReduce(max_input); __shared__ float s_max_input; if (threadIdx.x == 0) { s_max_input = max_input[0]; } __syncthreads(); for (int i = left_idx; i < right_idx; i += blockDim.x) { float logit = static_cast(inputs[i]) - s_max_input; sum_logits[0] += expf(logit); } blockReduce(sum_logits); __shared__ float s_sum_exp; if (threadIdx.x == 0) { s_sum_exp = sum_logits[0]; } __syncthreads(); float eps_i = epsilon / (vocab_size - 1); float nll_weight = 1.0 - epsilon - eps_i; for (int i = left_idx; i < right_idx; i += blockDim.x) { float prob = expf(static_cast(inputs[i]) - s_max_input) / s_sum_exp; float grad = 0; grad += (vocab_size * prob - 1) * eps_i; grad += prob * nll_weight; if ((i - block_start) == target_tid) { grad 
    if ((i - block_start) == target_tid) {
      grad -= nll_weight;
    }
    grad_inputs[i] = grad_out * grad;
  }
}

template <typename T>
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
                             float *outputs_ptr, float *nll_loss_ptr,
                             float *loss_buffer, const int padding_idx,
                             const float epsilon, const int batch_size,
                             const int seq_len, const int vocab_size,
                             cudaStream_t stream) {
  int grid_dim = batch_size * seq_len;
  float *nll_loss_buffer = loss_buffer + grid_dim;
  ls_cross_entropy_fw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
      inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
      epsilon, vocab_size);

  int num_items = grid_dim;
  void *d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  // first call with d_temp_storage == NULL only queries the required
  // temp_storage_bytes (standard CUB two-phase pattern)
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage,
                                             temp_storage_bytes, loss_buffer,
                                             outputs_ptr, num_items, stream));
  CHECK_GPU_ERROR(
      g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage,
                                             temp_storage_bytes, loss_buffer,
                                             outputs_ptr, num_items, stream));
  CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(
      d_temp_storage, temp_storage_bytes, nll_loss_buffer, nll_loss_ptr,
      num_items, stream));
  CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
}

template void launch_cross_entropy_fw<float>(
    const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
    float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template void launch_cross_entropy_fw<__half>(
    const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
    float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template <typename T>
void launch_cross_entropy_bw(const float *grad_outputs_ptr,
                             const T *inputs_ptr, const int *targets_ptr,
                             T *grad_inputs_ptr, const int padding_idx,
                             const float epsilon, const int batch_size,
                             const int seq_len, const int vocab_size,
                             cudaStream_t stream) {
  int grid_dim = batch_size * seq_len;
  ls_cross_entropy_bw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
      grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
      epsilon, vocab_size);
}

template void launch_cross_entropy_bw<float>(
    const float *grad_outputs_ptr, const float *inputs_ptr,
    const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);

template void launch_cross_entropy_bw<__half>(
    const float *grad_outputs_ptr, const __half *inputs_ptr,
    const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
    const float epsilon, const int batch_size, const int seq_len,
    const int vocab_size, cudaStream_t stream);
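
// Usage sketch (illustrative only; sizes and device-pointer names below are
// made up, and error checking is elided). It shows one way a host could call
// the forward launcher. Note that loss_buffer must hold
// 2 * batch_size * seq_len floats, because launch_cross_entropy_fw reuses its
// second half as nll_loss_buffer:
//
//   int batch_size = 8, seq_len = 128, vocab_size = 32000;
//   int tokens = batch_size * seq_len;
//   float *d_logits, *d_loss_buffer, *d_loss, *d_nll_loss;
//   int *d_targets;
//   cudaMalloc(&d_logits, sizeof(float) * tokens * vocab_size);
//   cudaMalloc(&d_targets, sizeof(int) * tokens);
//   cudaMalloc(&d_loss_buffer, sizeof(float) * tokens * 2);
//   cudaMalloc(&d_loss, sizeof(float));
//   cudaMalloc(&d_nll_loss, sizeof(float));
//   // ... copy logits and target ids to the device ...
//   launch_cross_entropy_fw<float>(d_logits, d_targets, d_loss, d_nll_loss,
//                                  d_loss_buffer, /*padding_idx=*/0,
//                                  /*epsilon=*/0.1f, batch_size, seq_len,
//                                  vocab_size, /*stream=*/0);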