mirror of https://github.com/hpcaitech/ColossalAI
from typing import List

from torch import Tensor
from torch.distributed import ProcessGroup

# Type stubs for the fused multi-head attention CUDA kernels. Each routine
# comes in an FP32 and an FP16 variant; a layer is registered once via a
# create_* call and then addressed by its integer layer_id in the forward
# and backward calls.


# FP32 forward pass of a fused multi-head attention layer.
def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...


# FP16 forward pass; same contract as the FP32 variant.
def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...


# FP32 backward pass: consumes the upstream gradient (grad_dec_output) along
# with the tensors from the forward pass and returns gradients as a list.
def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...


# FP16 backward pass; same contract as the FP32 variant.
def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...


# Registers an FP32 attention layer sized for the given capacity limits
# (max_batch_tokens, max_seq_len) and dropout ratios, communicating over the
# given process group.
def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...


# FP16 counterpart of create_multihead_attention_fp32.
def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...
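
Below is a minimal usage sketch of the FP32 bindings, assuming they are backed by ColossalAI's compiled CUDA extension (the stubs above are declarations only and return nothing on their own). The tensor shapes, the packed-QKV weight layout, the use of the layer_id integer as the handle, the additive-mask convention, and the contents of the returned lists are assumptions for illustration, not guarantees from the stubs.

import torch
import torch.distributed as dist

# Hypothetical sizes for illustration; nothing in the stubs fixes these.
batch, seq_len, hidden_dim, num_heads = 8, 128, 512, 8

# Requires an initialized default process group first, e.g.:
#   dist.init_process_group(backend="nccl", rank=0, world_size=1)
pg = dist.group.WORLD

# Register an FP32 layer with the extension. We assume the layer_id passed in
# is the handle used by the forward/backward calls; the stubs do not document
# the return value.
layer_id = 0
create_multihead_attention_fp32(layer_id, batch * seq_len, seq_len,
                                hidden_dim, num_heads,
                                0.1,   # attn_prob_dropout_ratio
                                0.1,   # hidden_dropout_ratio
                                True,  # pre_or_postLayerNorm
                                pg)

# Parameters in the packed-QKV layout assumed above.
device = "cuda"
in_proj_weight = torch.randn(3 * hidden_dim, hidden_dim, device=device)
in_proj_bias = torch.zeros(3 * hidden_dim, device=device)
out_proj_weight = torch.randn(hidden_dim, hidden_dim, device=device)
out_proj_bias = torch.zeros(hidden_dim, device=device)
norm_weight = torch.ones(hidden_dim, device=device)
norm_bias = torch.zeros(hidden_dim, device=device)

x = torch.randn(batch, seq_len, hidden_dim, device=device)
mask = torch.zeros(batch, seq_len, device=device)  # assumed additive mask

outputs = multihead_attention_fw_fp32(layer_id, x, mask,
                                      in_proj_weight, in_proj_bias,
                                      out_proj_weight, out_proj_bias,
                                      norm_weight, norm_bias,
                                      True,  # training_mode
                                      True)  # prelayernorm

# Backward takes the upstream gradient plus the forward tensors; treating
# outputs[0] as the layer output is an assumption about the returned list.
grads = multihead_attention_bw_fp32(layer_id, torch.ones_like(x),
                                    outputs[0], x, mask,
                                    in_proj_weight, in_proj_bias,
                                    out_proj_weight, out_proj_bias,
                                    norm_weight, norm_bias)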