@@ -23,8 +23,8 @@ using colossalAI::funcs::UnaryOpFunctor;
 using colossalAI::funcs::UnaryOpType;
 using colossalAI::funcs::warp_reduce;
 using colossalAI::funcs::ReduceType;
-using colossalAI::cuda::utils::copy_vector;
-using colossalAI::cuda::utils::copy_zero_vector;
+using colossalAI::cuda::utils::copy;
+using colossalAI::cuda::utils::copy_zero;
 
 /*
  * Extended softmax (from native aten pytorch) with following additional
@@ -75,8 +75,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
       int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
 
       if (element_index < batch_element_count) {
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_data, src + i * element_count * stride + it * WARP_SIZE);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            src + i * element_count * stride + it * WARP_SIZE, temp_data);
 
 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -140,10 +140,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
             out[element] = 0;
           }
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            dst + i * element_count * stride + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, dst + i * element_count * stride + it * WARP_SIZE);
       } else if (element_index < element_count) {
-        copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
+        copy_zero<output_t, ELEMENTS_PER_LDG_STG>(
             dst + i * element_count * stride + it * WARP_SIZE);
       } else {
         break;
@@ -199,10 +199,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
       int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
       if (element_index < batch_element_count) {
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_output, output + i * element_count * stride + it * WARP_SIZE);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            grad + i * element_count * stride + it * WARP_SIZE, temp_grad);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            output + i * element_count * stride + it * WARP_SIZE, temp_output);
 
 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -248,8 +248,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
               (output_t)(scale * (grad_reg[i][it + element] -
                                   output_reg[i][it + element] * sum[i]));
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            gradInput + i * element_count * stride + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, gradInput + i * element_count * stride + it * WARP_SIZE);
       }
     }
   }
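
Note on the renamed helpers: along with the copy_vector -> copy and
copy_zero_vector -> copy_zero renames, the argument order of the copy helper
flips from memcpy-style (dst, src) to std::copy-style (src, dst), while
copy_zero keeps its single destination argument. As a rough sketch only,
inferred from the call sites in the hunks above rather than from the actual
ColossalAI definitions (which live in the CUDA utils headers and may differ),
the new helpers plausibly look like this:

// Hypothetical sketch of the renamed vectorized-copy utilities; shapes are
// inferred from the call sites above, not the real ColossalAI implementations.
namespace colossalAI {
namespace cuda {
namespace utils {

// Copy VecSize elements of T with a single wide, aligned load/store.
// Note the new argument order: source first, destination second.
template <typename T, int VecSize>
__device__ __forceinline__ void copy(const T *src, T *dst) {
  struct alignas(sizeof(T) * VecSize) Vec {
    T data[VecSize];
  };
  *reinterpret_cast<Vec *>(dst) = *reinterpret_cast<const Vec *>(src);
}

// Zero-fill VecSize elements of T with a single wide store.
template <typename T, int VecSize>
__device__ __forceinline__ void copy_zero(T *dst) {
  struct alignas(sizeof(T) * VecSize) Vec {
    T data[VecSize];
  };
  Vec zero{};  // value-initialization zeroes the elements
  *reinterpret_cast<Vec *>(dst) = zero;
}

}  // namespace utils
}  // namespace cuda
}  // namespace colossalAI

Under this reading, the updated call sites read as "copy ELEMENTS_PER_LDG_STG
contiguous elements from the global-memory offset into the register buffer"
(copy) and "store ELEMENTS_PER_LDG_STG zeros at the offset" (copy_zero), with
ELEMENTS_PER_LDG_STG sized so that each access compiles to one vectorized
LDG/STG instruction.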