# ColossalAI/colossalai/_C/scaled_upper_triang_masked_...

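from torch import Tensor

# Signatures only; the kernels are implemented in the compiled `_C` extension.
def forward(input: Tensor, scale: float) -> Tensor:
    # Softmax of `input * scale` with entries above the diagonal masked out.
    ...

def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    # Gradient w.r.t. the forward input, from the upstream grads and saved softmax output.
    ...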
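The stub only declares signatures, so the math behind the two kernels is not visible here. Below is a minimal pure-Python reference sketch for a single square score matrix. It assumes, as the `_C` module path suggests, that entries above the diagonal (column j > row i) are masked, and that `backward` returns the gradient with respect to the unscaled input, which means it includes the factor `scale`. The names `reference_forward` and `reference_backward` are hypothetical and are not part of ColossalAI; the real kernels operate on batched CUDA tensors.

```python
import math
from typing import List

Matrix = List[List[float]]


def reference_forward(scores: Matrix, scale: float) -> Matrix:
    """Hypothetical reference: causal softmax of scale * scores for one square matrix.

    Assumption: row i may attend only to columns 0..i; columns j > i get probability 0.
    """
    out: Matrix = []
    for i, row in enumerate(scores):
        # Keep only the unmasked entries (lower triangle including the diagonal).
        visible = [scale * x for x in row[: i + 1]]
        m = max(visible)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in visible]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Masked (future) positions contribute exactly zero probability.
        out.append(probs + [0.0] * (len(row) - i - 1))
    return out


def reference_backward(grad_out: Matrix, probs: Matrix, scale: float) -> Matrix:
    """Hypothetical reference: gradient w.r.t. the unscaled scores.

    For y = softmax(scale * x): dL/dx_j = scale * y_j * (g_j - sum_k g_k * y_k).
    Masked positions have y_j = 0, so their gradient is 0 automatically.
    """
    grad_in: Matrix = []
    for g_row, y_row in zip(grad_out, probs):
        dot = sum(g * y for g, y in zip(g_row, y_row))
        grad_in.append([scale * y * (g - dot) for g, y in zip(g_row, y_row)])
    return grad_in
```

Because the backward pass needs only the saved softmax output and the upstream gradient, there is no need to keep the original scores around. That would explain why the stub's `backward` takes `softmax_results` rather than `input`.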