[builder] builder for scaled_upper_triang_masked_softmax (#2234)

pull/2237/head
Jiarui Fang 2022-12-30 09:58:00 +08:00 committed by GitHub
parent 31fe84237b
commit db4cbdc7fb
6 changed files with 53 additions and 18 deletions

@@ -18,6 +18,13 @@ except ImportError:
     from colossalai.kernel.op_builder import MultiHeadAttnBuilder
     multihead_attention = MultiHeadAttnBuilder().load()
 
+try:
+    from colossalai._C import scaled_upper_triang_masked_softmax
+except ImportError:
+    from colossalai.kernel.op_builder import ScaledSoftmaxBuilder
+    scaled_upper_triang_masked_softmax = ScaledSoftmaxBuilder().load()
+
 __all__ = [
-    "fused_optim", "cpu_optim", "multihead_attention", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
+    "fused_optim", "cpu_optim", "multihead_attention", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention",
+    "scaled_upper_triang_masked_softmax"
 ]
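
This hunk mirrors the import-or-build pattern already used for `multihead_attention` just above it: prefer the precompiled `colossalai._C` extension, and fall back to JIT-compiling the kernel through the op builder when the wheel was built without CUDA extensions. From the caller's point of view the two branches are indistinguishable; a minimal sketch (the fp16 dtype and (attn_batches, seq_len, seq_len) shape are assumptions about the kernel's constraints, not stated in the diff):

    # Downstream code imports from colossalai.kernel and never touches _C directly;
    # the first import may trigger a one-off JIT build if no prebuilt extension exists.
    import torch
    from colossalai.kernel import scaled_upper_triang_masked_softmax

    scale = torch.tensor([0.125])
    scores = torch.randn(4, 128, 128, dtype=torch.float16, device='cuda')
    probs = scaled_upper_triang_masked_softmax.forward(scores, scale[0])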

@@ -23,27 +23,20 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, inputs, scale):
-        try:
-            import colossalai._C.scaled_upper_triang_masked_softmax
-        except ImportError:
-            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
+        from colossalai.kernel import scaled_upper_triang_masked_softmax
 
         scale_t = torch.tensor([scale])
-        softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
+        softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
 
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
 
     @staticmethod
     def backward(ctx, output_grads):
-        try:
-            import colossalai._C.scaled_upper_triang_masked_softmax
-        except ImportError:
-            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
+        from colossalai.kernel import scaled_upper_triang_masked_softmax
 
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results,
-                                                                                scale_t[0])
+        input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
 
         return input_grads, None
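
With the lazy import in place, both `forward` and `backward` resolve the kernel through `colossalai.kernel`, so the JIT fallback also covers the autograd path. A short usage sketch; the module path and the fp16/shape constraints are assumed from the surrounding file, not shown in the diff:

    import torch
    from colossalai.kernel.cuda_native.scaled_softmax import ScaledUpperTriangMaskedSoftmax

    # Fused scale + causal (upper-triangular) mask + softmax in one kernel launch.
    x = torch.randn(4, 128, 128, dtype=torch.float16, device='cuda', requires_grad=True)
    probs = ScaledUpperTriangMaskedSoftmax.apply(x, 0.125)
    probs.sum().backward()    # backward() goes through the same CUDA extension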

@@ -1,5 +1,6 @@
 from .cpu_adam import CPUAdamBuilder
 from .fused_optim import FusedOptimBuilder
 from .multi_head_attn import MultiHeadAttnBuilder
+from .scaled_upper_triang_masked_softmax import ScaledSoftmaxBuilder
 
-__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder']
+__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder']

@@ -0,0 +1,36 @@
+import os
+
+from .builder import Builder, get_cuda_cc_flag
+
+
+class ScaledSoftmaxBuilder(Builder):
+
+    def __init__(self):
+        self.base_dir = "cuda_native/csrc"
+        self.name = 'scaled_upper_triang_masked_softmax'
+        super().__init__()
+
+    def include_dirs(self):
+        ret = []
+        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
+        ret.append(os.path.join(self.base_dir, "kernels", "include"))
+        return [self.colossalai_src_path(path) for path in ret]
+
+    def sources_files(self):
+        ret = [
+            os.path.join(self.base_dir, fname)
+            for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu']
+        ]
+        return [self.colossalai_src_path(path) for path in ret]
+
+    def cxx_flags(self):
+        return ['-O3']
+
+    def nvcc_flags(self):
+        extra_cuda_flags = [
+            '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
+            '--expt-extended-lambda'
+        ]
+        extra_cuda_flags.extend(get_cuda_cc_flag())
+        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
+        return ret
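
The builder centralizes everything the extension needs to compile: include paths, source files, and host/device compiler flags, so the runtime JIT path and `setup.py` no longer duplicate this information. Loading it is a one-liner, as the `__init__.py` hunk above already shows:

    from colossalai.kernel.op_builder import ScaledSoftmaxBuilder

    # load() (inherited from Builder) compiles scaled_upper_triang_masked_softmax.{cpp,cu}
    # with the flags above on the first call and returns the loaded extension module;
    # later calls presumably hit torch's cpp_extension build cache rather than recompiling.
    kernel = ScaledSoftmaxBuilder().load()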

@@ -324,7 +324,7 @@ def main():
         if n >= WARMUP_STEPS:
             tflops_list.append(step_tflops)
 
-    logger.info(f"max memory {torch.cuda.memory_allocated() / 1024**2} MB", ranks=[0])
+    logger.info(f"max memory {torch.cuda.max_memory_allocated() / 1024**2} MB", ranks=[0])
 
     tflops_list.sort()
     median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
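
This one-line fix changes what the benchmark actually reports: `torch.cuda.memory_allocated()` is the memory occupied by live tensors at the moment of the call, while `torch.cuda.max_memory_allocated()` is the high-water mark since startup (or the last `torch.cuda.reset_peak_memory_stats()`), which is what a "max memory" log line should print. The difference in miniature:

    import torch

    a = torch.empty(1024, 1024, device='cuda')    # ~4 MB of float32
    del a                                         # freed back to the caching allocator
    torch.cuda.memory_allocated()                 # current usage: back near zero
    torch.cuda.max_memory_allocated()             # peak usage: still records the ~4 MB spike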

@@ -154,10 +154,8 @@ if build_cuda_ext:
         '--expt-extended-lambda'
     ]
 
-    ext_modules.append(
-        cuda_ext_helper('colossalai._C.scaled_upper_triang_masked_softmax',
-                        ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'],
-                        extra_cuda_flags + cc_flag))
+    from colossalai.kernel.op_builder import ScaledSoftmaxBuilder
+    ext_modules.append(ScaledSoftmaxBuilder().builder('colossalai._C.scaled_upper_triang_masked_softmax'))
 
     ext_modules.append(
         cuda_ext_helper('colossalai._C.scaled_masked_softmax',
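
With this change, `setup.py` stops hand-listing sources and flags and asks the builder for a ready-made extension object, so the ahead-of-time wheel build and the runtime JIT load share one source of truth. The diff does not show `Builder.builder()` itself; a plausible sketch of what it returns, assuming it wraps the builder's metadata in a `torch.utils.cpp_extension.CUDAExtension` (an assumption, not the verbatim implementation):

    from torch.utils.cpp_extension import CUDAExtension

    # Hypothetical sketch of Builder.builder(): reuse the same include dirs,
    # sources, and flags that the JIT path uses for an AOT extension build.
    def builder(self, name):
        return CUDAExtension(
            name=name,
            sources=self.sources_files(),
            include_dirs=self.include_dirs(),
            extra_compile_args={'cxx': self.cxx_flags(), 'nvcc': self.nvcc_flags()},
        )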