diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index 05c6ee35b..24e458bb3 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -180,4 +180,9 @@ class FusedScaleMaskSoftmax(nn.Module): return probs def get_batch_per_block(self, sq, sk, b, np): + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() + return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)