mirror of https://github.com/hpcaitech/ColossalAI
[builder] MOE builder (#2277)
parent 26e171af6c
commit 16cc8e6aa7

@@ -24,7 +24,19 @@ except ImportError:
     from colossalai.kernel.op_builder import ScaledSoftmaxBuilder
     scaled_upper_triang_masked_softmax = ScaledSoftmaxBuilder().load()
 
+try:
+    from colossalai._C import moe
+except ImportError:
+    from colossalai.kernel.op_builder import MOEBuilder
+    moe = MOEBuilder().load()
+
 __all__ = [
-    "fused_optim", "cpu_optim", "multihead_attention", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention",
-    "scaled_upper_triang_masked_softmax"
+    "fused_optim",
+    "cpu_optim",
+    "multihead_attention",
+    "moe",
+    "LayerNorm",
+    "FusedScaleMaskSoftmax",
+    "MultiHeadAttention",
+    "scaled_upper_triang_masked_softmax",
 ]
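
The new block follows the same pattern as the other kernels in this file: prefer the precompiled colossalai._C.moe extension, and fall back to a JIT build through the op builder. Downstream code can then just import the module from colossalai.kernel; a minimal caller sketch, mirroring the availability-flag idiom used in the MoE ops later in this diff (the flag name below is illustrative, not part of the commit):

# Sketch: consume the kernel module without caring how it was built.
MOE_KERNEL_AVAILABLE = False
try:
    from colossalai.kernel import moe    # resolves to colossalai._C.moe or the JIT build above
    MOE_KERNEL_AVAILABLE = True
except ImportError:
    moe = None    # fall back to pure-PyTorch code paths
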
@@ -1,6 +1,7 @@
 from .cpu_adam import CPUAdamBuilder
 from .fused_optim import FusedOptimBuilder
+from .moe import MOEBuilder
 from .multi_head_attn import MultiHeadAttnBuilder
 from .scaled_upper_triang_masked_softmax import ScaledSoftmaxBuilder
 
-__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder']
+__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder', 'MOEBuilder']
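
With MOEBuilder exported next to the existing builders, every JIT-compiled kernel is reachable through one package. A small sketch, assuming the other builders expose the same load() entry point that ScaledSoftmaxBuilder and MOEBuilder use in this diff:

from colossalai.kernel.op_builder import (CPUAdamBuilder, FusedOptimBuilder, MOEBuilder,
                                          MultiHeadAttnBuilder, ScaledSoftmaxBuilder)

# Illustrative: pre-build every exported kernel, e.g. to warm a compilation cache in CI.
for builder_cls in (CPUAdamBuilder, FusedOptimBuilder, MultiHeadAttnBuilder, ScaledSoftmaxBuilder, MOEBuilder):
    kernel_module = builder_cls().load()    # each builder JIT-compiles its own extension
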
@@ -1,12 +1,12 @@
 import os
 import re
-import sys
 from pathlib import Path
+from typing import List
 
 import torch
 
 
-def get_cuda_cc_flag():
+def get_cuda_cc_flag() -> List:
     """get_cuda_cc_flag
 
     cc flag for your GPU arch
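
The new -> List annotation matches how the helper is consumed by the MOEBuilder added below (extra_cuda_flags.extend(get_cuda_cc_flag())). The body of get_cuda_cc_flag is not part of this hunk; purely as an illustration, a compute-capability helper of this shape typically looks like the following (the function name is hypothetical):

from typing import List

import torch


def get_cuda_cc_flag_sketch() -> List:
    # Illustration only: emit '-gencode' flags for each visible GPU's compute
    # capability, e.g. ['-gencode', 'arch=compute_80,code=sm_80'].
    cc_flag = []
    for idx in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(idx)
        cc_flag.extend(['-gencode', f'arch=compute_{major}{minor},code=sm_{major}{minor}'])
    return cc_flag
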
@@ -0,0 +1,33 @@
+import os
+
+from .builder import Builder, get_cuda_cc_flag
+
+
+class MOEBuilder(Builder):
+
+    def __init__(self):
+        self.base_dir = "cuda_native/csrc"
+        self.name = 'moe'
+        super().__init__()
+
+    def include_dirs(self):
+        ret = []
+        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
+        ret.append(os.path.join(self.base_dir, "kernels", "include"))
+        return [self.colossalai_src_path(path) for path in ret]
+
+    def sources_files(self):
+        ret = [os.path.join(self.base_dir, fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']]
+        return [self.colossalai_src_path(path) for path in ret]
+
+    def cxx_flags(self):
+        return ['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
+
+    def nvcc_flags(self):
+        extra_cuda_flags = [
+            '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
+            '--expt-extended-lambda'
+        ]
+        extra_cuda_flags.extend(get_cuda_cc_flag())
+        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
+        return ret
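
For context, this builder is exercised in two ways elsewhere in the commit: JIT compilation at import time in the first hunk above, and ahead-of-time building from setup.py in the last hunk below. A minimal usage sketch:

from colossalai.kernel.op_builder import MOEBuilder

# JIT path: compile moe_cuda.cpp / moe_cuda_kernel.cu and return the loaded extension module.
moe = MOEBuilder().load()

# Ahead-of-time path: produce an extension object for setup.py's ext_modules.
ext = MOEBuilder().builder('colossalai._C.moe')
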
@@ -6,12 +6,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 
 COL_MOE_KERNEL_FLAG = False
-try:
-    import colossalai._C.moe
-
-    COL_MOE_KERNEL_FLAG = True
-except ImportError:
-    print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
+from colossalai.kernel import moe
 
 
 class AllGather(torch.autograd.Function):
@@ -90,7 +85,7 @@ class MoeDispatch(torch.autograd.Function):
         s = tokens.size(0)
         h = tokens.size(1)
 
-        expert_input = colossalai._C.moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
+        expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
 
         ctx.save_for_backward(mask, dest_idx)
         ctx.s = s
@@ -102,7 +97,7 @@ class MoeDispatch(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
-        d_tokens = colossalai._C.moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
         return d_tokens, None, None, None
 
 
@@ -119,7 +114,7 @@ class MoeCombine(torch.autograd.Function):
 
         fp16_flag = (expert_tokens.dtype == torch.float16)
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
-        ctokens = colossalai._C.moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
+        ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
         output = ctokens.to(torch.float16) if fp16_flag else ctokens
 
         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
@@ -138,8 +133,7 @@ class MoeCombine(torch.autograd.Function):
         cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
             else tokens_grad
         cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = colossalai._C.moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
-                                                                mask, dest_idx)
+        d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
         d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
 
         return d_expert, d_logits, None, None, None
@@ -149,6 +143,6 @@ def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
     if flag and COL_MOE_KERNEL_FLAG:
-        return colossalai._C.moe.cumsum_sub_one(inputs)
+        return moe.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
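
For intuition, the pure-PyTorch branch of moe_cumsum defines what the CUDA cumsum_sub_one kernel is expected to return (a cumulative sum shifted down by one); the kernel path is only taken when the row count satisfies the size condition and the kernel flag is set. A quick sanity check, assuming the kernel matches the fallback:

import torch

x = torch.tensor([[1, 0], [0, 1], [1, 1]])
ref = torch.cumsum(x, dim=0) - 1    # the fallback branch of moe_cumsum
print(ref)
# tensor([[ 0, -1],
#         [ 0,  0],
#         [ 1,  1]])
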
setup.py (6 changed lines)
@@ -1,7 +1,7 @@
 import os
 import re
 
-from setuptools import Extension, find_packages, setup
+from setuptools import find_packages, setup
 
 from colossalai.kernel.op_builder.utils import get_cuda_bare_metal_version
 
@@ -161,8 +161,8 @@ if build_cuda_ext:
         cuda_ext_helper('colossalai._C.scaled_masked_softmax',
                         ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag))
 
-    ext_modules.append(
-        cuda_ext_helper('colossalai._C.moe', ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], extra_cuda_flags + cc_flag))
+    from colossalai.kernel.op_builder import MOEBuilder
+    ext_modules.append(MOEBuilder().builder('colossalai._C.moe'))
 
     extra_cuda_flags = ['-maxrregcount=50']
 
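
The builder('colossalai._C.moe') call lets setup.py reuse the same source, include, and flag lists as the JIT path. The base Builder class is not shown in this diff; as a hedged sketch of what such a method plausibly assembles from the hooks MOEBuilder defines (sources_files, include_dirs, cxx_flags, nvcc_flags), using PyTorch's standard CUDAExtension:

from torch.utils.cpp_extension import CUDAExtension


def build_extension_sketch(op_builder, name):
    # Illustration only: how Builder.builder() is assumed to combine the hooks
    # into a setuptools extension that torch's BuildExtension can compile.
    return CUDAExtension(
        name=name,
        sources=op_builder.sources_files(),
        include_dirs=op_builder.include_dirs(),
        extra_compile_args={
            'cxx': op_builder.cxx_flags(),
            'nvcc': op_builder.nvcc_flags(),
        },
    )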