mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
13 lines
306 B
13 lines
306 B
2 years ago
|
from torch import Tensor
|
||
|
|
||
|
def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor:
    """Stub signature for the scaled-masked-softmax forward pass.

    Placeholder (`...`) body: the real implementation presumably lives in a
    compiled extension — TODO confirm against the build.

    Args:
        input: attention scores to be softmaxed.
        mask: mask applied to ``input`` before the softmax.
        scale: scalar multiplier applied to ``input``.

    Returns:
        The softmax results (per the annotation; the stub itself returns None).
    """
    ...
|
||
|
|
||
|
|
||
|
def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    """Stub signature for the scaled-masked-softmax backward pass.

    Placeholder (`...`) body: the real implementation presumably lives in a
    compiled extension — TODO confirm against the build.

    Args:
        output_grads: gradients flowing back from the next layer.
        softmax_results: saved forward-pass softmax output.
        scale: the same scalar multiplier used in the forward pass.

    Returns:
        Input gradients (per the annotation; the stub itself returns None).
    """
    ...
|
||
|
|
||
|
|
||
|
def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int:
    """Stub signature for querying the kernel's batch-per-block sizing.

    Placeholder (`...`) body: the real implementation presumably lives in a
    compiled extension — TODO confirm against the build.

    Args:
        query_seq_len: query sequence length.
        key_seq_len: key sequence length.
        batches: number of batches.
        attn_heads: number of attention heads.

    Returns:
        Batches handled per block (per the annotation; the stub returns None).
    """
    ...
|