mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp code style (#952)
parent
fa43bb216d
commit
d8d07b0e2b
|
@ -10,8 +10,9 @@
|
|||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, int max_seq_len,
|
||||
int hidden_size, int num_heads,
|
||||
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
|
||||
int max_seq_len, int hidden_size,
|
||||
int num_heads,
|
||||
float attn_prob_dropout_ratio,
|
||||
float hidden_output_dropout_ratio,
|
||||
bool pre_or_postLayerNorm)
|
||||
|
@ -22,18 +23,22 @@ MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, in
|
|||
_heads(num_heads),
|
||||
_training(true),
|
||||
_pre_or_postLayerNorm(pre_or_postLayerNorm),
|
||||
_qkv_linear(typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
|
||||
_attn_out_linear(typename FeedForward<T>::Config(hidden_size, hidden_size)),
|
||||
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false), _max_batch_tokens),
|
||||
_qkv_linear(
|
||||
typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
|
||||
_attn_out_linear(
|
||||
typename FeedForward<T>::Config(hidden_size, hidden_size)),
|
||||
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
|
||||
_max_batch_tokens),
|
||||
_softmax(typename Softmax<T>::Config(num_heads)),
|
||||
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
|
||||
_max_batch_tokens * _heads * _max_seq_len),
|
||||
_attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
|
||||
_max_batch_tokens * _hidden_size),
|
||||
_attn_scores(typename StridedBatchGemm<T>::Config((T(1.0) / T(sqrt(_hidden_size / _heads))),
|
||||
T(0.0), CUBLAS_OP_T, CUBLAS_OP_N)),
|
||||
_attn_context(
|
||||
typename StridedBatchGemm<T>::Config(T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
|
||||
_attn_scores(typename StridedBatchGemm<T>::Config(
|
||||
(T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
|
||||
CUBLAS_OP_N)),
|
||||
_attn_context(typename StridedBatchGemm<T>::Config(
|
||||
T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
|
||||
assert(_hidden_size % _heads == 0);
|
||||
}
|
||||
|
||||
|
@ -43,43 +48,52 @@ MultiHeadAttention<T>::~MultiHeadAttention() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mask_ptr,
|
||||
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
|
||||
const T *input_mask_ptr,
|
||||
T *output_ptr, T *buffer) {
|
||||
T *q_tf_ptr = _qkv_ptr;
|
||||
T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
|
||||
T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
|
||||
_stream);
|
||||
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
|
||||
_batch_tokens, _stream);
|
||||
}
|
||||
const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
const T *gemmQKV_inp_ptr =
|
||||
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
|
||||
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, _cublasHandle);
|
||||
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
|
||||
_cublasHandle);
|
||||
|
||||
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr, _batch_size, _seq_len, 3,
|
||||
_heads / pg_size, _hidden_size / _heads, _stream);
|
||||
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
|
||||
_batch_size, _seq_len, 3, _heads / pg_size,
|
||||
_hidden_size / _heads, _stream);
|
||||
|
||||
// attention scores, q*k
|
||||
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
|
||||
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
|
||||
_cublasHandle);
|
||||
|
||||
// Softmax + Mask
|
||||
_softmax.reset_size(_heads / pg_size);
|
||||
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, _seq_len, _stream, true);
|
||||
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
|
||||
_seq_len, _stream, true);
|
||||
|
||||
// attn prob dropout.
|
||||
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, _batch_heads * _seq_len * _seq_len,
|
||||
_stream);
|
||||
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
|
||||
_batch_heads * _seq_len * _seq_len, _stream);
|
||||
|
||||
// attention context, score * v
|
||||
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle);
|
||||
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
|
||||
_cublasHandle);
|
||||
|
||||
// [b, nh, s, ad] -> [b, s, nh, ad]
|
||||
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, _hidden_size / pg_size,
|
||||
_heads / pg_size, 1, _stream);
|
||||
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, 1,
|
||||
_stream);
|
||||
|
||||
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
|
||||
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, output_ptr, _cublasHandle);
|
||||
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
|
||||
output_ptr, _cublasHandle);
|
||||
|
||||
// allreduce
|
||||
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
|
||||
|
@ -88,24 +102,27 @@ void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mas
|
|||
if (typeid(T) != typeid(float)) {
|
||||
data_type = torch::kHalf;
|
||||
}
|
||||
auto output_tensor =
|
||||
torch::from_blob(output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
auto output_tensor = torch::from_blob(
|
||||
output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
|
||||
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
|
||||
work->wait();
|
||||
}
|
||||
|
||||
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, _attn_ob_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
|
||||
_attn_ob_ptr, _batch_tokens, _hidden_size,
|
||||
_stream);
|
||||
if (!_pre_or_postLayerNorm) {
|
||||
// in-place ln since ln-input will not be used in post-ln mode
|
||||
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, _stream);
|
||||
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
|
||||
_batch_tokens, _stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr) {
|
||||
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
|
||||
T *out_ptr) {
|
||||
_stream = Context::Instance().get_stream();
|
||||
_cublasHandle = Context::Instance().get_cublashandle();
|
||||
T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim
|
||||
|
@ -114,8 +131,11 @@ void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
|
||||
const T *grad_output_ptr, T *grad_input_ptr, T *buffer) {
|
||||
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
|
||||
const T *input_mask_ptr,
|
||||
const T *output_ptr,
|
||||
const T *grad_output_ptr,
|
||||
T *grad_input_ptr, T *buffer) {
|
||||
cudaStream_t streams[2] = {_stream, _stream};
|
||||
|
||||
const T *q_tf_ptr = _qkv_ptr;
|
||||
|
@ -137,45 +157,57 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
|
|||
// batch_size * head_num * seq_len * seq_len);
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_output_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
|
||||
grad_output_ptr, _batch_tokens,
|
||||
_hidden_size, _stream);
|
||||
} else {
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, grad_output_ptr,
|
||||
nullptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_residual_ptr,
|
||||
_batch_tokens, _hidden_size, _stream);
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
|
||||
grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
|
||||
_attn_nb_ptr, _batch_tokens, streams);
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
|
||||
grad_residual_ptr, _batch_tokens,
|
||||
_hidden_size, _stream);
|
||||
}
|
||||
|
||||
// bw of output project
|
||||
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
|
||||
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, _attn_ow_ptr,
|
||||
_grad_attn_ow_ptr, _grad_attn_ob_ptr, _cublasHandle, _stream,
|
||||
grad_input_buf_ptr, nullptr, false);
|
||||
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, _stream);
|
||||
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
|
||||
_attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
|
||||
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
|
||||
false);
|
||||
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
|
||||
_seq_len, _hidden_size / pg_size, _heads / pg_size,
|
||||
_stream);
|
||||
|
||||
// bw of score * v
|
||||
_attn_context.Backward(_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
|
||||
_attn_context.Backward(
|
||||
_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
|
||||
|
||||
_attn_prob_dropout.d_dropout(grad_softmax_ptr, _batch_heads * _seq_len * _seq_len, _stream);
|
||||
_attn_prob_dropout.d_dropout(grad_softmax_ptr,
|
||||
_batch_heads * _seq_len * _seq_len, _stream);
|
||||
|
||||
_softmax.reset_size(_heads / pg_size);
|
||||
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, _seq_len, _stream);
|
||||
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
|
||||
_seq_len, _stream);
|
||||
|
||||
// bw of q * k
|
||||
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + _batch_dim / pg_size, grad_qkv_5d_ptr);
|
||||
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
|
||||
_cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
|
||||
grad_qkv_5d_ptr);
|
||||
|
||||
// [3, b, nh, s, ad] -> [b, s, 3, h]
|
||||
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, 3, _stream);
|
||||
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
|
||||
_seq_len, _hidden_size / pg_size, _heads / pg_size,
|
||||
3, _stream);
|
||||
|
||||
const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
const T *gemmQKV_inp_ptr =
|
||||
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
|
||||
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, _attn_qkvw_ptr,
|
||||
_grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, _cublasHandle, _stream,
|
||||
grad_input_buf_ptr, nullptr, true);
|
||||
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
|
||||
_attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
|
||||
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
|
||||
true);
|
||||
|
||||
// allreduce
|
||||
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
|
||||
|
@ -185,7 +217,8 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
|
|||
data_type = torch::kHalf;
|
||||
}
|
||||
auto grad_input_tensor =
|
||||
torch::from_blob(grad_input_buf_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::from_blob(grad_input_buf_ptr,
|
||||
{int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
|
||||
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
|
||||
|
@ -193,19 +226,21 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
|
|||
}
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, grad_input_buf_ptr,
|
||||
grad_output_ptr, gemmQKV_inp_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
|
||||
streams);
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
|
||||
grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
|
||||
_attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
|
||||
} else {
|
||||
// FIXME later
|
||||
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, _batch_size,
|
||||
_seq_len, _hidden_size, _stream);
|
||||
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
|
||||
_batch_size, _seq_len, _hidden_size, _stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
|
||||
const T *input_mask_ptr, T *grad_input_ptr) {
|
||||
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
|
||||
const T *input_ptr, const T *output_ptr,
|
||||
const T *input_mask_ptr,
|
||||
T *grad_input_ptr) {
|
||||
_stream = Context::Instance().get_stream();
|
||||
_cublasHandle = Context::Instance().get_cublashandle();
|
||||
T *buffer = _shared_mem_ptr;
|
||||
|
@ -215,7 +250,8 @@ void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_pt
|
|||
4 * _batch_dim + max(3 * _batch_dim,
|
||||
_batch_size * _head_num * _seq_len * _seq_len);
|
||||
*/
|
||||
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, grad_input_ptr, buffer);
|
||||
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
|
||||
grad_input_ptr, buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -233,7 +269,8 @@ template class MultiHeadAttention<__half>;
|
|||
|
||||
// x is torch::Tensor
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
@ -241,15 +278,17 @@ template class MultiHeadAttention<__half>;
|
|||
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
|
||||
|
||||
template <typename T>
|
||||
int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_len, int hidden_dim,
|
||||
int num_heads, float attn_prob_dropout_ratio,
|
||||
float hidden_dropout_ratio, bool pre_or_postLayerNorm,
|
||||
int create_multihead_attention(int layer_id, int max_batch_tokens,
|
||||
int max_seq_len, int hidden_dim, int num_heads,
|
||||
float attn_prob_dropout_ratio,
|
||||
float hidden_dropout_ratio,
|
||||
bool pre_or_postLayerNorm,
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
Context::Instance().set_stream(stream);
|
||||
auto layer = std::make_shared<MultiHeadAttention<T>>(
|
||||
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, attn_prob_dropout_ratio,
|
||||
hidden_dropout_ratio, pre_or_postLayerNorm);
|
||||
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
|
||||
attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
|
||||
|
||||
layer->SetPG(pg_);
|
||||
|
||||
|
@ -261,15 +300,12 @@ int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_l
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias,
|
||||
bool training_mode, bool prelayernorm) {
|
||||
std::vector<torch::Tensor> multihead_attention_fw(
|
||||
int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
|
||||
bool training_mode, bool prelayernorm) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(input_mask);
|
||||
|
||||
|
@ -280,7 +316,8 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
|
|||
T *out_ptr = (T *)output.data_ptr();
|
||||
|
||||
std::shared_ptr<MultiHeadAttention<T>> layer =
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(
|
||||
s_multihead_attention[layer_id]);
|
||||
layer->set_cur_batch_shape(input.size(0), input.size(1));
|
||||
layer->SetTrainingMode(training_mode);
|
||||
|
||||
|
@ -297,17 +334,13 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
|
||||
const torch::Tensor &grad_dec_output,
|
||||
const torch::Tensor &output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias) {
|
||||
std::vector<torch::Tensor> multihead_attention_bw(
|
||||
int layer_id, const torch::Tensor &grad_dec_output,
|
||||
const torch::Tensor &output, const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias) {
|
||||
auto g_output = grad_dec_output.contiguous();
|
||||
CHECK_INPUT(g_output);
|
||||
CHECK_INPUT(output);
|
||||
|
@ -332,7 +365,8 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
|
|||
T *grad_input_ptr = (T *)grad_input.data_ptr();
|
||||
|
||||
std::shared_ptr<MultiHeadAttention<T>> layer =
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(
|
||||
s_multihead_attention[layer_id]);
|
||||
layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
|
||||
|
||||
layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
|
||||
|
@ -342,10 +376,12 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
|
|||
layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
|
||||
layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
|
||||
|
||||
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, grad_input_ptr);
|
||||
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
|
||||
grad_input_ptr);
|
||||
|
||||
return {grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
|
||||
grad_out_proj_bias, grad_norm_weight, grad_norm_bias};
|
||||
return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
|
||||
grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
|
||||
grad_norm_bias};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
|
|
Loading…
Reference in New Issue