diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
index 1dd84773a..70b3419d8 100644
--- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
+++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
@@ -19,21 +19,25 @@
 template <typename T>
 class MultiHeadAttention {
  public:
-  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
-                     int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
+  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
+                     int hidden_size, int num_heads, float attn_dropout_ratio,
+                     float hidden_output_dropout_ratio,
                      bool pre_or_postLayerNorm);

   virtual ~MultiHeadAttention();

   void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);

-  void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
-                const T *input_mask_ptr, T *grad_input_ptr);
+  void Backward(const T *grad_output_ptr, const T *input_ptr,
+                const T *output_ptr, const T *input_mask_ptr,
+                T *grad_input_ptr);

-  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);
+  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
+                     T *buffer);

-  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
-                     const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);
+  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
+                     const T *output_ptr, const T *grad_output_ptr,
+                     T *grad_input_attn_layer_bwptr, T *buffer);

   void set_cur_batch_shape(int batch_size, int seq_len) {
     _batch_size = batch_size;
@@ -83,14 +87,17 @@ class MultiHeadAttention {
     }

     _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
-    _soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
-    _ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
+    _soft_out_ptr =
+        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
+    _ctx_bufB_ptr =
+        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
     _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);

     // buffer size needed by attn bw
-    size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
-                       std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
-                                _max_batch_tokens * _heads / pg_size * _max_seq_len);
+    size_t smem_size =
+        4 * _max_batch_tokens * _hidden_size / pg_size +
+        std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
+                 _max_batch_tokens * _heads / pg_size * _max_seq_len);

     if (!_shared_mem_ptr) {
       cuda_free(_shared_mem_ptr);