From 40d376c566f6f4e7fa8a7ae63d9c9b4f6178413c Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 6 Jan 2023 20:50:26 +0800 Subject: [PATCH] [setup] support pre-build and jit-build of cuda kernels (#2374) * [setup] support pre-build and jit-build of cuda kernels * polish code * polish code * polish code * polish code * polish code * polish code --- .gitignore | 7 + colossalai/_C/__init__.py | 0 colossalai/_C/__init__.pyi | 9 - colossalai/_C/cpu_optim.pyi | 8 - colossalai/_C/fused_optim.pyi | 23 --- colossalai/_C/layer_norm.pyi | 11 -- colossalai/_C/moe.pyi | 20 --- colossalai/_C/multihead_attention.pyi | 55 ------ colossalai/_C/scaled_masked_softmax.pyi | 12 -- .../_C/scaled_upper_triang_masked_softmax.pyi | 8 - colossalai/amp/naive_amp/_fp16_optimizer.py | 16 +- colossalai/kernel/__init__.py | 35 ---- .../kernel/cuda_native/multihead_attention.py | 3 +- colossalai/nn/layer/moe/_operation.py | 29 +++- colossalai/nn/optimizer/cpu_adam.py | 9 +- colossalai/nn/optimizer/fused_adam.py | 3 +- colossalai/nn/optimizer/fused_lamb.py | 3 +- colossalai/nn/optimizer/fused_sgd.py | 3 +- colossalai/nn/optimizer/hybrid_adam.py | 5 +- colossalai/utils/common.py | 13 +- .../multi_tensor_apply/multi_tensor_apply.py | 1 - op_builder/README.md | 31 ++++ op_builder/__init__.py | 20 ++- op_builder/builder.py | 159 ++++++++++++------ op_builder/cpu_adam.py | 18 +- op_builder/fused_optim.py | 23 ++- op_builder/layernorm.py | 29 ++++ op_builder/moe.py | 27 +-- op_builder/multi_head_attn.py | 28 +-- op_builder/scaled_masked_softmax.py | 37 ++++ .../scaled_upper_triang_masked_softmax.py | 36 ---- .../scaled_upper_triangle_masked_softmax.py | 37 ++++ op_builder/utils.py | 22 +++ setup.py | 58 +------ tests/test_optimizer/test_cpu_adam.py | 3 +- .../test_optimizer/test_fused_adam_kernel.py | 3 +- 36 files changed, 414 insertions(+), 390 deletions(-) create mode 100644 colossalai/_C/__init__.py delete mode 100644 colossalai/_C/__init__.pyi delete mode 100644 colossalai/_C/cpu_optim.pyi delete mode 100644 colossalai/_C/fused_optim.pyi delete mode 100644 colossalai/_C/layer_norm.pyi delete mode 100644 colossalai/_C/moe.pyi delete mode 100644 colossalai/_C/multihead_attention.pyi delete mode 100644 colossalai/_C/scaled_masked_softmax.pyi delete mode 100644 colossalai/_C/scaled_upper_triang_masked_softmax.pyi create mode 100644 op_builder/README.md create mode 100644 op_builder/layernorm.py create mode 100644 op_builder/scaled_masked_softmax.py delete mode 100644 op_builder/scaled_upper_triang_masked_softmax.py create mode 100644 op_builder/scaled_upper_triangle_masked_softmax.py diff --git a/.gitignore b/.gitignore index 40f3f6deb..6b6f980e3 100644 --- a/.gitignore +++ b/.gitignore @@ -144,3 +144,10 @@ docs/.build # ignore version.py generated by setup.py colossalai/version.py + +# ignore any kernel build files +.o +.so + +# ignore python interface defition file +.pyi diff --git a/colossalai/_C/__init__.py b/colossalai/_C/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/_C/__init__.pyi b/colossalai/_C/__init__.pyi deleted file mode 100644 index bfd86d0ee..000000000 --- a/colossalai/_C/__init__.pyi +++ /dev/null @@ -1,9 +0,0 @@ -from . import ( - cpu_optim, - fused_optim, - layer_norm, - moe, - multihead_attention, - scaled_masked_softmax, - scaled_upper_triang_masked_softmax, -) diff --git a/colossalai/_C/cpu_optim.pyi b/colossalai/_C/cpu_optim.pyi deleted file mode 100644 index 0f7611790..000000000 --- a/colossalai/_C/cpu_optim.pyi +++ /dev/null @@ -1,8 +0,0 @@ -from torch import Tensor - -class CPUAdamOptimizer: - def __init__(self, lr: float, beta1: float, beta2: float, eps: float, - weight_decay: float, adamw_mode: float) -> None: ... - - def step(self, step: int, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, bias_correction: bool, - param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor, loss_scale: float) -> None: ... diff --git a/colossalai/_C/fused_optim.pyi b/colossalai/_C/fused_optim.pyi deleted file mode 100644 index 983b02335..000000000 --- a/colossalai/_C/fused_optim.pyi +++ /dev/null @@ -1,23 +0,0 @@ -from typing import List - -from torch import Tensor - -def multi_tensor_scale(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], scale: float) -> None: - ... - - -def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], weight_decay: float, - momentum: float, dampening: float, lr: float, nesterov: bool, first_run: bool, weight_decay_after_momentum: bool, scale: float) -> None: - ... - - -def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None: - ... - - -def multi_tensor_lamb(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, bias_correction: int, weight_decay: float, grad_averaging: int, mode: int, global_grad_norm: Tensor, max_grad_norm: float, use_nvlamb_python: bool) -> None: - ... - - -def multi_tensor_l2norm(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], per_tensor_python: bool) -> None: - ... diff --git a/colossalai/_C/layer_norm.pyi b/colossalai/_C/layer_norm.pyi deleted file mode 100644 index 02d4587ff..000000000 --- a/colossalai/_C/layer_norm.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import List - -from torch import Tensor - -def forward_affine(input: Tensor, normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]: - ... - - -def backward_affine(dout: Tensor, mean: Tensor, invvar: Tensor, input: Tensor, - normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]: - ... diff --git a/colossalai/_C/moe.pyi b/colossalai/_C/moe.pyi deleted file mode 100644 index 121aa7e41..000000000 --- a/colossalai/_C/moe.pyi +++ /dev/null @@ -1,20 +0,0 @@ -from torch import Tensor - -def cumsum_sub_one(mask: Tensor) -> Tensor: - ... - - -def dispatch_forward(s: int, ec: int, h: int, batch_tokens: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: - ... - - -def dispatch_backward(s: int, ec: int, h: int, expert_grad: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: - ... - - -def combine_forward(s: int, e: int, c: int, h: int, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: - ... - - -def combine_backward(s: int, e: int, c: int, h: int, tokens_grad: Tensor, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor: - ... diff --git a/colossalai/_C/multihead_attention.pyi b/colossalai/_C/multihead_attention.pyi deleted file mode 100644 index 7ad87ea9a..000000000 --- a/colossalai/_C/multihead_attention.pyi +++ /dev/null @@ -1,55 +0,0 @@ -from typing import List - -from torch import Tensor -from torch.distributed import ProcessGroup - -def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor, - in_proj_weight: Tensor, in_proj_bias: Tensor, - out_proj_weight: Tensor, out_proj_bias: Tensor, - norm_weight: Tensor, norm_bias: Tensor, - training_mode: bool, prelayernorm: bool) -> List[Tensor]: - ... - - -def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor, - in_proj_weight: Tensor, in_proj_bias: Tensor, - out_proj_weight: Tensor, out_proj_bias: Tensor, - norm_weight: Tensor, norm_bias: Tensor, - training_mode: bool, prelayernorm: bool) -> List[Tensor]: - ... - - -def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor, - output: Tensor, input: Tensor, - input_mask: Tensor, in_proj_weight: Tensor, - in_proj_bias: Tensor, out_proj_weight: Tensor, - out_proj_bias: Tensor, norm_weight: Tensor, - norm_bias: Tensor) -> List[Tensor]: - ... - - -def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor, - output: Tensor, input: Tensor, - input_mask: Tensor, in_proj_weight: Tensor, - in_proj_bias: Tensor, out_proj_weight: Tensor, - out_proj_bias: Tensor, norm_weight: Tensor, - norm_bias: Tensor) -> List[Tensor]: - ... - - -def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int, - max_seq_len: int, hidden_dim: int, num_heads: int, - attn_prob_dropout_ratio: float, - hidden_dropout_ratio: float, - pre_or_postLayerNorm: bool, - pg: ProcessGroup) -> int: - ... - - -def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int, - max_seq_len: int, hidden_dim: int, num_heads: int, - attn_prob_dropout_ratio: float, - hidden_dropout_ratio: float, - pre_or_postLayerNorm: bool, - pg: ProcessGroup) -> int: - ... diff --git a/colossalai/_C/scaled_masked_softmax.pyi b/colossalai/_C/scaled_masked_softmax.pyi deleted file mode 100644 index fdb88266e..000000000 --- a/colossalai/_C/scaled_masked_softmax.pyi +++ /dev/null @@ -1,12 +0,0 @@ -from torch import Tensor - -def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor: - ... - - -def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor: - ... - - -def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int: - ... diff --git a/colossalai/_C/scaled_upper_triang_masked_softmax.pyi b/colossalai/_C/scaled_upper_triang_masked_softmax.pyi deleted file mode 100644 index 39a3d6b22..000000000 --- a/colossalai/_C/scaled_upper_triang_masked_softmax.pyi +++ /dev/null @@ -1,8 +0,0 @@ -from torch import Tensor - -def forward(input: Tensor, scale: float) -> Tensor: - ... - - -def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor: - ... diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py index 3f2c4c2ed..e4699f92b 100644 --- a/colossalai/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/amp/naive_amp/_fp16_optimizer.py @@ -8,16 +8,28 @@ from torch.optim import Optimizer from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.kernel import fused_optim +from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.logging import get_dist_logger from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier from ._utils import has_inf_or_nan, zero_gard_by_list from .grad_scaler import BaseGradScaler +try: + from colossalai._C import fused_optim +except: + fused_optim = None + __all__ = ['FP16Optimizer'] +def load_fused_optim(): + global fused_optim + + if fused_optim is None: + fused_optim = FusedOptimBuilder().load() + + def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): """ adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM) @@ -30,6 +42,8 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): if overflow_buf: overflow_buf.fill_(0) # Scaling with factor `1.0` is equivalent to copy. + global fused_optim + load_fused_optim() multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0) else: for this_, that_ in zip(this, that): diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 02d000362..8933fc0a3 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,42 +1,7 @@ from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention -try: - from colossalai._C import fused_optim -except: - from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder - fused_optim = FusedOptimBuilder().load() - -try: - from colossalai._C import cpu_optim -except ImportError: - from colossalai.kernel.op_builder import CPUAdamBuilder - cpu_optim = CPUAdamBuilder().load() - -try: - from colossalai._C import multihead_attention -except ImportError: - from colossalai.kernel.op_builder import MultiHeadAttnBuilder - multihead_attention = MultiHeadAttnBuilder().load() - -try: - from colossalai._C import scaled_upper_triang_masked_softmax -except ImportError: - from colossalai.kernel.op_builder import ScaledSoftmaxBuilder - scaled_upper_triang_masked_softmax = ScaledSoftmaxBuilder().load() - -try: - from colossalai._C import moe -except ImportError: - from colossalai.kernel.op_builder import MOEBuilder - moe = MOEBuilder().load() - __all__ = [ - "fused_optim", - "cpu_optim", - "multihead_attention", - "moe", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention", - "scaled_upper_triang_masked_softmax", ] diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py index 2c7503453..7df53731e 100644 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -135,7 +135,8 @@ class MultiHeadAttention(nn.Module): # Load cuda modules if needed global colossal_multihead_attention if colossal_multihead_attention is None: - from colossalai.kernel import multihead_attention + from colossalai.kernel.op_builder import MultiHeadAttnBuilder + multihead_attention = MultiHeadAttnBuilder().load() colossal_multihead_attention = multihead_attention # create the layer in cuda kernels. diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index d06025db1..37f31c167 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -6,13 +6,32 @@ from torch import Tensor from torch.distributed import ProcessGroup COL_MOE_KERNEL_FLAG = False -from colossalai.kernel import moe + +try: + from colossalai._C import moe +except: + moe = None + + +def build_moe_if_not_prebuilt(): + # load moe kernel during runtime if not pre-built + global moe + if moe is None: + from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() class AllGather(torch.autograd.Function): @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + + global moe + + if moe is None: + from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() + if ctx is not None: ctx.comm_grp = group @@ -85,6 +104,9 @@ class MoeDispatch(torch.autograd.Function): s = tokens.size(0) h = tokens.size(1) + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() + expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) ctx.save_for_backward(mask, dest_idx) @@ -112,6 +134,9 @@ class MoeCombine(torch.autograd.Function): c = ec // e h = expert_tokens.size(-1) + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() + fp16_flag = (expert_tokens.dtype == torch.float16) cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) @@ -143,6 +168,8 @@ def moe_cumsum(inputs: Tensor): dim0 = inputs.size(0) flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) if flag and COL_MOE_KERNEL_FLAG: + # load moe kernel during runtime if not pre-built + build_moe_if_not_prebuilt() return moe.cumsum_sub_one(inputs) else: return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 5b05fecc8..a8c352279 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -3,6 +3,7 @@ from typing import Optional import torch +from colossalai.kernel.op_builder import CPUAdamBuilder from colossalai.registry import OPTIMIZERS from .nvme_optimizer import NVMeOptimizer @@ -76,12 +77,8 @@ class CPUAdam(NVMeOptimizer): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - try: - import colossalai._C.cpu_optim - except ImportError: - raise ImportError('Please install colossalai from source code to use CPUAdam') - self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, - adamw_mode) + cpu_adam = CPUAdamBuilder().load() + self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) def torch_adam_update(self, data, diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index c81d122d4..2f6bde5ca 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -65,7 +65,8 @@ class FusedAdam(torch.optim.Optimizer): self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: - from colossalai.kernel import fused_optim + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index a78b351fc..891a76da7 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -76,7 +76,8 @@ class FusedLAMB(torch.optim.Optimizer): max_grad_norm=max_grad_norm) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - from colossalai.kernel import fused_optim + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 2596c0bcd..41e6d5248 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -80,7 +80,8 @@ class FusedSGD(Optimizer): self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - from colossalai.kernel import fused_optim + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() # Skip buffer self._dummy_overflow_buf = torch.tensor([0], diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 5504411aa..5196d4338 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -2,6 +2,7 @@ from typing import Any, Optional import torch +from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder from colossalai.registry import OPTIMIZERS from colossalai.utils import multi_tensor_applier @@ -77,7 +78,9 @@ class HybridAdam(NVMeOptimizer): super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - from colossalai.kernel import cpu_optim, fused_optim + # build during runtime if not found + cpu_optim = CPUAdamBuilder().load() + fused_optim = FusedOptimBuilder().load() self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) self.gpu_adam_op = fused_optim.multi_tensor_adam diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 3ff72d037..7575fa292 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -18,11 +18,15 @@ from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARA from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env -from colossalai.kernel import fused_optim from colossalai.tensor import ColoParameter, ProcessGroup from .multi_tensor_apply import multi_tensor_applier +try: + from colossalai._C import fused_optim +except: + fused_optim = None + def print_rank_0(msg: str, logger=None): """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu. @@ -123,6 +127,13 @@ def is_model_parallel_parameter(p): def _calc_l2_norm(grads): + # we should not + global fused_optim + + if fused_optim is None: + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + norm = 0.0 if len(grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py index b9d98d019..2b6de5fe1 100644 --- a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py +++ b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py @@ -14,7 +14,6 @@ class MultiTensorApply(object): def __init__(self, chunk_size): try: - from colossalai.kernel import fused_optim MultiTensorApply.available = True self.chunk_size = chunk_size except ImportError as err: diff --git a/op_builder/README.md b/op_builder/README.md new file mode 100644 index 000000000..057da1038 --- /dev/null +++ b/op_builder/README.md @@ -0,0 +1,31 @@ +# Build PyTorch Extensions + +## Overview + +Building PyTorch extensions can be a difficult task for users not from the system background. It is definitely frustrating if the users encounter many strange technical jargons when install Colossal-AI. Therefore, we will provide two methods of building the PyTorch extensions for the users. + +1. Build CUDA extensions when running `pip install` if `CUDA_EXT=1` +2. Build the extension during runtime + +The first method is more suitable for users who are familiar with CUDA environment configurations. The second method is for those who are not as they only need to build the kernel which is required by their program. + +These two methods have different advantages and disadvantages. +Method 1 is good because it allows the user to build all kernels during installation and directly import the kernel. They don't need to care about kernel building when running their program. However, installation may fail if they don't know how to configure their environments and this leads to much frustration. +Method 2 is good because it allows the user to only build the kernel they actually need, such that there is a lower probability that they encounter environment issue. However, it may slow down their program due to the first build and subsequence load. + +## PyTorch Extensions in Colossal-AI + +As mentioned in the section above, our aim is to make these two methods coherently supported in Colossal-AI, meaning that for a kernel should be either built in `setup.py` or during runtime. +There are mainly two functions used to build extensions. + +1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`. +2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime + +Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong). + +We have implemented the following conventions: + +1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C` +2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete) + +When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered. diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 08832fc55..5ae7223b8 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -1,7 +1,23 @@ from .cpu_adam import CPUAdamBuilder from .fused_optim import FusedOptimBuilder +from .layernorm import LayerNormBuilder from .moe import MOEBuilder from .multi_head_attn import MultiHeadAttnBuilder -from .scaled_upper_triang_masked_softmax import ScaledSoftmaxBuilder +from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder +from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder -__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder', 'MOEBuilder'] +ALL_OPS = { + 'cpu_adam': CPUAdamBuilder, + 'fused_optim': FusedOptimBuilder, + 'moe': MOEBuilder, + 'multi_head_attn': MultiHeadAttnBuilder, + 'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder, + 'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder, + 'layernorm': LayerNormBuilder, +} + +__all__ = [ + 'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder', + 'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder', + 'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder' +] diff --git a/op_builder/builder.py b/op_builder/builder.py index 2e3728397..dc9ea8e11 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -1,40 +1,49 @@ +import importlib import os -import re +import time +from abc import ABC, abstractmethod from pathlib import Path from typing import List -def get_cuda_cc_flag() -> List: - """get_cuda_cc_flag +class Builder(ABC): + """ + Builder is the base class to build extensions for PyTorch. - cc flag for your GPU arch + Args: + name (str): the name of the kernel to be built + prebuilt_import_path (str): the path where the extension is installed during pip install """ - # only import torch when needed - # this is to avoid importing torch when building on a machine without torch pre-installed - # one case is to build wheel for pypi release - import torch - - cc_flag = [] - for arch in torch.cuda.get_arch_list(): - res = re.search(r'sm_(\d+)', arch) - if res: - arch_cap = res[1] - if int(arch_cap) >= 60: - cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + def __init__(self, name: str, prebuilt_import_path: str): + self.name = name + self.prebuilt_import_path = prebuilt_import_path + self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - return cc_flag + assert prebuilt_import_path.startswith('colossalai._C'), \ + f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}' + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + op_builder_module_path = Path(__file__).parent -class Builder(object): - - def colossalai_src_path(self, code_path): - current_file_path = Path(__file__) - if os.path.islink(current_file_path.parent): - # symbolic link - return os.path.join(current_file_path.parent.parent.absolute(), code_path) + # if we install from source + # the current file path will be op_builder/builder.py + # if we install via pip install colossalai + # the current file path will be colossalai/kernel/op_builder/builder.py + # this is because that the op_builder inside colossalai is a symlink + # this symlink will be replaced with actual files if we install via pypi + # thus we cannot tell the colossalai root directory by checking whether the op_builder + # is a symlink, we can only tell whether it is inside or outside colossalai + if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'): + root_path = op_builder_module_path.parent.parent else: - return os.path.join(current_file_path.parent.parent.absolute(), "colossalai", "kernel", code_path) + root_path = op_builder_module_path.parent.joinpath('colossalai') + + code_abs_path = root_path.joinpath(code_path) + return str(code_abs_path) def get_cuda_home_include(self): """ @@ -46,47 +55,94 @@ class Builder(object): cuda_include = os.path.join(CUDA_HOME, "include") return cuda_include + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path) + # functions must be overrided begin - def sources_files(self): + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ raise NotImplementedError - def include_dirs(self): - raise NotImplementedError + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of inlcude files for extensions. + """ + pass - def cxx_flags(self): - raise NotImplementedError + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + pass - def nvcc_flags(self): - raise NotImplementedError + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + pass # functions must be overrided over - def strip_empty_entries(self, args): ''' Drop any empty strings from the list of compile and link flags ''' return [x for x in args if len(x) > 0] + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + def load(self, verbose=True): """ + load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel. + If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the + kernel is built during pip install, it can be accessed through `colossalai._C`. - load and compile cpu_adam lib at runtime + Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. Args: verbose (bool, optional): show detailed info. Defaults to True. """ - import time - from torch.utils.cpp_extension import load start_build = time.time() - op_module = load(name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - verbose=verbose) + try: + op_module = self.import_op() + if verbose: + print(f"OP {self.prebuilt_import_path} already exists, skip building.") + except ImportError: + # construct the build directory + import torch + torch_version_major = torch.__version__.split('.')[0] + torch_version_minor = torch.__version__.split('.')[1] + torch_cuda_version = torch.version.cuda + home_directory = os.path.expanduser('~') + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" + build_directory = os.path.join(home_directory, extension_directory) + Path(build_directory).mkdir(parents=True, exist_ok=True) + + if verbose: + print("=========================================================================================") + print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now") + print("=========================================================================================") + + # load the kernel + op_module = load(name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=build_directory, + verbose=verbose) build_duration = time.time() - start_build if verbose: @@ -94,17 +150,16 @@ class Builder(object): return op_module - def builder(self, name) -> 'CUDAExtension': + def builder(self) -> 'CUDAExtension': """ get a CUDAExtension instance used for setup.py """ from torch.utils.cpp_extension import CUDAExtension - return CUDAExtension( - name=name, - sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources_files()], - include_dirs=self.include_dirs(), - extra_compile_args={ - 'cxx': self.cxx_flags(), - 'nvcc': self.nvcc_flags() - }) + return CUDAExtension(name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + 'cxx': self.strip_empty_entries(self.cxx_flags()), + 'nvcc': self.strip_empty_entries(self.nvcc_flags()) + }) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 7b5b46319..500e2cc0e 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -6,24 +6,22 @@ from .utils import append_nvcc_threads class CPUAdamBuilder(Builder): NAME = "cpu_adam" - BASE_DIR = "cuda_native" + PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" def __init__(self): - self.name = CPUAdamBuilder.NAME - super().__init__() - + super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] # necessary 4 functions def sources_files(self): ret = [ - os.path.join(CPUAdamBuilder.BASE_DIR, "csrc/cpu_adam.cpp"), + self.csrc_abs_path('cpu_adam.cpp'), ] - return [self.colossalai_src_path(path) for path in ret] + return ret def include_dirs(self): return [ - self.colossalai_src_path(os.path.join(CPUAdamBuilder.BASE_DIR, "includes")), + self.csrc_abs_path("includes"), self.get_cuda_home_include() ] @@ -36,7 +34,5 @@ class CPUAdamBuilder(Builder): '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' ] - - return append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags) - - # necessary 4 functions + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py index 1f1bb9e11..31ddfced1 100644 --- a/op_builder/fused_optim.py +++ b/op_builder/fused_optim.py @@ -1,20 +1,19 @@ import os -from .builder import Builder, get_cuda_cc_flag +from .builder import Builder +from .utils import get_cuda_cc_flag class FusedOptimBuilder(Builder): - NAME = 'fused_optim' - BASE_DIR = "cuda_native/csrc" + NAME = "fused_optim" + PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim" def __init__(self): - self.name = FusedOptimBuilder.NAME - super().__init__() - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - + super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) + def sources_files(self): ret = [ - self.colossalai_src_path(os.path.join(FusedOptimBuilder.BASE_DIR, fname)) for fname in [ + self.csrc_abs_path(fname) for fname in [ 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' ] @@ -22,12 +21,12 @@ class FusedOptimBuilder(Builder): return ret def include_dirs(self): - ret = [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), self.get_cuda_home_include()] - return [self.colossalai_src_path(path) for path in ret] + ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + return ret def cxx_flags(self): - extra_cxx_flags = [] - return ['-O3'] + self.version_dependent_macros + extra_cxx_flags + version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + return ['-O3'] + version_dependent_macros def nvcc_flags(self): extra_cuda_flags = ['-lineinfo'] diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py new file mode 100644 index 000000000..61d941741 --- /dev/null +++ b/op_builder/layernorm.py @@ -0,0 +1,29 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormBuilder(Builder): + NAME = "layernorm" + PREBUILT_IMPORT_PATH = "colossalai._C.layernorm" + + def __init__(self): + super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu']] + return ret + + def include_dirs(self): + ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ['-maxrregcount=50'] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + self.version_dependent_macros + return append_nvcc_threads(ret) diff --git a/op_builder/moe.py b/op_builder/moe.py index 5f74e1a72..eeb7d8e39 100644 --- a/op_builder/moe.py +++ b/op_builder/moe.py @@ -1,27 +1,30 @@ import os -from .builder import Builder, get_cuda_cc_flag +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag class MOEBuilder(Builder): + NAME = "moe" + PREBUILT_IMPORT_PATH = "colossalai._C.moe" + def __init__(self): - self.base_dir = "cuda_native/csrc" - self.name = 'moe' - super().__init__() + super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): - ret = [] - ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()] - ret.append(os.path.join(self.base_dir, "kernels", "include")) - return [self.colossalai_src_path(path) for path in ret] + ret = [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + return ret def sources_files(self): - ret = [os.path.join(self.base_dir, fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']] - return [self.colossalai_src_path(path) for path in ret] + ret = [self.csrc_abs_path(fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']] + return ret def cxx_flags(self): - return ['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + return ['-O3'] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ @@ -30,4 +33,4 @@ class MOEBuilder(Builder): ] extra_cuda_flags.extend(get_cuda_cc_flag()) ret = ['-O3', '--use_fast_math'] + extra_cuda_flags - return ret + return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py index f6eaf6c3d..f9103fe94 100644 --- a/op_builder/multi_head_attn.py +++ b/op_builder/multi_head_attn.py @@ -1,32 +1,32 @@ import os -from .builder import Builder, get_cuda_cc_flag +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag class MultiHeadAttnBuilder(Builder): - def __init__(self): - self.base_dir = "cuda_native/csrc" - self.name = 'multihead_attention' - super().__init__() + NAME = "multihead_attention" + PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + def __init__(self): + super().__init__(name=MultiHeadAttnBuilder.NAME, + prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) + def include_dirs(self): - ret = [] - ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()] - ret.append(os.path.join(self.base_dir, "kernels", "include")) - return [self.colossalai_src_path(path) for path in ret] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + return ret def sources_files(self): ret = [ - os.path.join(self.base_dir, fname) for fname in [ + self.csrc_abs_path(fname) for fname in [ 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' ] ] - return [self.colossalai_src_path(path) for path in ret] + return ret def cxx_flags(self): return ['-O3'] + self.version_dependent_macros @@ -37,5 +37,5 @@ class MultiHeadAttnBuilder(Builder): '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags - return ret + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py new file mode 100644 index 000000000..11cfda39a --- /dev/null +++ b/op_builder/scaled_masked_softmax.py @@ -0,0 +1,37 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads + + +class ScaledMaskedSoftmaxBuilder(Builder): + NAME = "scaled_masked_softmax" + PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" + + def __init__(self): + super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in + ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'] + ] + return ret + + def include_dirs(self): + return [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triang_masked_softmax.py b/op_builder/scaled_upper_triang_masked_softmax.py deleted file mode 100644 index c64c6a5e5..000000000 --- a/op_builder/scaled_upper_triang_masked_softmax.py +++ /dev/null @@ -1,36 +0,0 @@ -import os - -from .builder import Builder, get_cuda_cc_flag - - -class ScaledSoftmaxBuilder(Builder): - - def __init__(self): - self.base_dir = "cuda_native/csrc" - self.name = 'scaled_upper_triang_masked_softmax' - super().__init__() - - def include_dirs(self): - ret = [] - ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()] - ret.append(os.path.join(self.base_dir, "kernels", "include")) - return [self.colossalai_src_path(path) for path in ret] - - def sources_files(self): - ret = [ - os.path.join(self.base_dir, fname) - for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'] - ] - return [self.colossalai_src_path(path) for path in ret] - - def cxx_flags(self): - return ['-O3'] - - def nvcc_flags(self): - extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags - return ret diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py new file mode 100644 index 000000000..d0d2433aa --- /dev/null +++ b/op_builder/scaled_upper_triangle_masked_softmax.py @@ -0,0 +1,37 @@ +import os + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): + NAME = "scaled_upper_triangle_masked_softmax" + PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" + + def __init__(self): + super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + return [ + self.csrc_abs_path("kernels/include"), + self.get_cuda_home_include() + ] + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', + '--expt-extended-lambda' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/op_builder/utils.py b/op_builder/utils.py index 757df4efc..b6bada99e 100644 --- a/op_builder/utils.py +++ b/op_builder/utils.py @@ -1,4 +1,6 @@ +import re import subprocess +from typing import List def get_cuda_bare_metal_version(cuda_dir): @@ -11,6 +13,26 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_major, bare_metal_minor +def get_cuda_cc_flag() -> List: + """get_cuda_cc_flag + + cc flag for your GPU arch + """ + + # only import torch when needed + # this is to avoid importing torch when building on a machine without torch pre-installed + # one case is to build wheel for pypi release + import torch + + cc_flag = [] + for arch in torch.cuda.get_arch_list(): + res = re.search(r'sm_(\d+)', arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 60: + cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + + return cc_flag def append_nvcc_threads(nvcc_extra_args): from torch.utils.cpp_extension import CUDA_HOME diff --git a/setup.py b/setup.py index 62cea133f..38d5fa91c 100644 --- a/setup.py +++ b/setup.py @@ -133,59 +133,11 @@ if build_cuda_ext: # and # https://github.com/NVIDIA/apex/issues/456 # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac - version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]): - return CUDAExtension( - name=name, - sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources], - include_dirs=[os.path.join(this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')], - extra_compile_args={ - 'cxx': ['-O3'] + version_dependent_macros + extra_cxx_flags, - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags) - }) - - #### fused optim kernels ### - from op_builder import FusedOptimBuilder - ext_modules.append(FusedOptimBuilder().builder('colossalai._C.fused_optim')) - - #### N-D parallel kernels ### - cc_flag = [] - for arch in torch.cuda.get_arch_list(): - res = re.search(r'sm_(\d+)', arch) - if res: - arch_cap = res[1] - if int(arch_cap) >= 60: - cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) - - extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' - ] - - from op_builder import ScaledSoftmaxBuilder - ext_modules.append(ScaledSoftmaxBuilder().builder('colossalai._C.scaled_upper_triang_masked_softmax')) - - ext_modules.append( - cuda_ext_helper('colossalai._C.scaled_masked_softmax', - ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag)) - - from op_builder import MOEBuilder - ext_modules.append(MOEBuilder().builder('colossalai._C.moe')) - - extra_cuda_flags = ['-maxrregcount=50'] - - ext_modules.append( - cuda_ext_helper('colossalai._C.layer_norm', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'], - extra_cuda_flags + cc_flag)) - - ### MultiHeadAttn Kernel #### - from op_builder import MultiHeadAttnBuilder - ext_modules.append(MultiHeadAttnBuilder().builder('colossalai._C.multihead_attention')) - - ### Gemini Adam kernel #### - from op_builder import CPUAdamBuilder - ext_modules.append(CPUAdamBuilder().builder('colossalai._C.cpu_optim')) + from op_builder import ALL_OPS + for name, builder_cls in ALL_OPS.items(): + print(f'===== Building Extension {name} =====') + ext_modules.append(builder_cls().builder()) setup(name='colossalai', version=get_version(), @@ -227,4 +179,4 @@ setup(name='colossalai', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: System :: Distributed Computing', ], - package_data={'colossalai': ['_C/*.pyi']}) + package_data={'colossalai': ['_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', 'kernel/cuda_native/csrc/kernels/include/*']}) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py index 9b835af50..d317dc2e3 100644 --- a/tests/test_optimizer/test_cpu_adam.py +++ b/tests/test_optimizer/test_cpu_adam.py @@ -66,7 +66,8 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype): exp_avg_sq = torch.rand(p_data.shape) exp_avg_sq_copy = exp_avg_sq.clone() - from colossalai.kernel import cpu_optim + from colossalai.kernel.op_builder import CPUAdamBuilder + cpu_optim = CPUAdamBuilder().load() cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py index f0188e9fa..7b9b6e9c4 100644 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -46,7 +46,8 @@ def torch_adam_update( @parameterize('p_dtype', [torch.float, torch.half]) @parameterize('g_dtype', [torch.float, torch.half]) def test_adam(adamw, step, p_dtype, g_dtype): - from colossalai.kernel import fused_optim + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() fused_adam = fused_optim.multi_tensor_adam dummy_overflow_buf = torch.cuda.IntTensor([0])