[kernel] added kernel loader to softmax autograd function (#3093)

* [kernel] added kernel loader to softmax autograd function

* [release] v0.2.6
pull/3094/head
Frank Lee 2023-03-10 14:27:09 +08:00 committed by GitHub
parent fff98f06ed
commit 95a36eae63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 0 deletions

View File

@ -180,4 +180,9 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs
def get_batch_per_block(self, sq, sk, b, np):
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)