mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/kernel/cuda_native/scaled_softmax.py code style (#955)
parent
f6970ef8b1
commit
8ca2a85682
|
@ -28,9 +28,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||||
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
||||||
|
|
||||||
scale_t = torch.tensor([scale])
|
scale_t = torch.tensor([scale])
|
||||||
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
|
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
|
||||||
inputs, scale_t[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.save_for_backward(softmax_results, scale_t)
|
ctx.save_for_backward(softmax_results, scale_t)
|
||||||
return softmax_results
|
return softmax_results
|
||||||
|
@ -43,9 +41,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||||
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
||||||
|
|
||||||
softmax_results, scale_t = ctx.saved_tensors
|
softmax_results, scale_t = ctx.saved_tensors
|
||||||
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
|
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
|
||||||
output_grads, softmax_results, scale_t[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
return input_grads, None
|
return input_grads, None
|
||||||
|
|
||||||
|
@ -81,9 +77,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
|
||||||
|
|
||||||
softmax_results, scale_t = ctx.saved_tensors
|
softmax_results, scale_t = ctx.saved_tensors
|
||||||
|
|
||||||
input_grads = colossal_scaled_masked_softmax.backward(
|
input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
|
||||||
output_grads, softmax_results, scale_t[0]
|
|
||||||
)
|
|
||||||
return input_grads, None, None
|
return input_grads, None, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -114,9 +108,8 @@ class FusedScaleMaskSoftmax(nn.Module):
|
||||||
super(FusedScaleMaskSoftmax, self).__init__()
|
super(FusedScaleMaskSoftmax, self).__init__()
|
||||||
self.input_in_fp16 = input_in_fp16
|
self.input_in_fp16 = input_in_fp16
|
||||||
self.input_in_bf16 = input_in_bf16
|
self.input_in_bf16 = input_in_bf16
|
||||||
assert not (
|
assert not (self.input_in_fp16
|
||||||
self.input_in_fp16 and self.input_in_bf16
|
and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time."
|
||||||
), "both fp16 and bf16 flags cannot be active at the same time."
|
|
||||||
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
|
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
|
||||||
self.attn_mask_type = attn_mask_type
|
self.attn_mask_type = attn_mask_type
|
||||||
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
|
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
|
||||||
|
@ -124,9 +117,7 @@ class FusedScaleMaskSoftmax(nn.Module):
|
||||||
self.softmax_in_fp32 = softmax_in_fp32
|
self.softmax_in_fp32 = softmax_in_fp32
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
assert (
|
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
|
||||||
self.scale is None or softmax_in_fp32
|
|
||||||
), "softmax should be in fp32 when scaled"
|
|
||||||
|
|
||||||
def forward(self, input, mask):
|
def forward(self, input, mask):
|
||||||
# [b, np, sq, sk]
|
# [b, np, sq, sk]
|
||||||
|
@ -140,14 +131,13 @@ class FusedScaleMaskSoftmax(nn.Module):
|
||||||
def is_kernel_available(self, mask, b, np, sq, sk):
|
def is_kernel_available(self, mask, b, np, sq, sk):
|
||||||
attn_batches = b * np
|
attn_batches = b * np
|
||||||
|
|
||||||
if (
|
if (self.scaled_masked_softmax_fusion # user wants to fuse
|
||||||
self.scaled_masked_softmax_fusion # user wants to fuse
|
and self.input_in_float16 # input must be fp16 or bf16
|
||||||
and self.input_in_float16 # input must be fp16 or bf16
|
and mask is not None # mask tensor must not be None
|
||||||
and mask is not None # mask tensor must not be None
|
and 16 < sk <= 2048 # sk must be greater than 16 and at most 2048
|
||||||
and 16 < sk <= 2048 # sk must be greater than 16 and at most 2048
|
and sq % 4 == 0 # sq must be a multiple of 4
|
||||||
and sq % 4 == 0 # sq must be a multiple of 4
|
and attn_batches % 4 == 0 # np * b must be a multiple of 4
|
||||||
and attn_batches % 4 == 0 # np * b must be a multiple of 4
|
):
|
||||||
):
|
|
||||||
if 0 <= sk <= 2048:
|
if 0 <= sk <= 2048:
|
||||||
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
|
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue