[setup] support pre-build and jit-build of cuda kernels (#2374)

* [setup] support pre-build and jit-build of cuda kernels

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
pull/2389/head
Frank Lee 2023-01-06 20:50:26 +08:00 committed by GitHub
parent 12c8bf38d7
commit 40d376c566
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 414 additions and 390 deletions

7
.gitignore vendored
View File

@ -144,3 +144,10 @@ docs/.build
# ignore version.py generated by setup.py
colossalai/version.py
# ignore any kernel build files
.o
.so
# ignore python interface defition file
.pyi

View File

View File

@ -1,9 +0,0 @@
from . import (
cpu_optim,
fused_optim,
layer_norm,
moe,
multihead_attention,
scaled_masked_softmax,
scaled_upper_triang_masked_softmax,
)

View File

@ -1,8 +0,0 @@
from torch import Tensor
class CPUAdamOptimizer:
def __init__(self, lr: float, beta1: float, beta2: float, eps: float,
weight_decay: float, adamw_mode: float) -> None: ...
def step(self, step: int, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, bias_correction: bool,
param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor, loss_scale: float) -> None: ...

View File

@ -1,23 +0,0 @@
from typing import List
from torch import Tensor
def multi_tensor_scale(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], scale: float) -> None:
...
def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], weight_decay: float,
momentum: float, dampening: float, lr: float, nesterov: bool, first_run: bool, weight_decay_after_momentum: bool, scale: float) -> None:
...
def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None:
...
def multi_tensor_lamb(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, bias_correction: int, weight_decay: float, grad_averaging: int, mode: int, global_grad_norm: Tensor, max_grad_norm: float, use_nvlamb_python: bool) -> None:
...
def multi_tensor_l2norm(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], per_tensor_python: bool) -> None:
...

View File

@ -1,11 +0,0 @@
from typing import List
from torch import Tensor
def forward_affine(input: Tensor, normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
...
def backward_affine(dout: Tensor, mean: Tensor, invvar: Tensor, input: Tensor,
normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
...

View File

@ -1,20 +0,0 @@
from torch import Tensor
def cumsum_sub_one(mask: Tensor) -> Tensor:
...
def dispatch_forward(s: int, ec: int, h: int, batch_tokens: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
...
def dispatch_backward(s: int, ec: int, h: int, expert_grad: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
...
def combine_forward(s: int, e: int, c: int, h: int, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
...
def combine_backward(s: int, e: int, c: int, h: int, tokens_grad: Tensor, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
...

View File

@ -1,55 +0,0 @@
from typing import List
from torch import Tensor
from torch.distributed import ProcessGroup
def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor,
in_proj_weight: Tensor, in_proj_bias: Tensor,
out_proj_weight: Tensor, out_proj_bias: Tensor,
norm_weight: Tensor, norm_bias: Tensor,
training_mode: bool, prelayernorm: bool) -> List[Tensor]:
...
def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor,
in_proj_weight: Tensor, in_proj_bias: Tensor,
out_proj_weight: Tensor, out_proj_bias: Tensor,
norm_weight: Tensor, norm_bias: Tensor,
training_mode: bool, prelayernorm: bool) -> List[Tensor]:
...
def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor,
output: Tensor, input: Tensor,
input_mask: Tensor, in_proj_weight: Tensor,
in_proj_bias: Tensor, out_proj_weight: Tensor,
out_proj_bias: Tensor, norm_weight: Tensor,
norm_bias: Tensor) -> List[Tensor]:
...
def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor,
output: Tensor, input: Tensor,
input_mask: Tensor, in_proj_weight: Tensor,
in_proj_bias: Tensor, out_proj_weight: Tensor,
out_proj_bias: Tensor, norm_weight: Tensor,
norm_bias: Tensor) -> List[Tensor]:
...
def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int,
max_seq_len: int, hidden_dim: int, num_heads: int,
attn_prob_dropout_ratio: float,
hidden_dropout_ratio: float,
pre_or_postLayerNorm: bool,
pg: ProcessGroup) -> int:
...
def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int,
max_seq_len: int, hidden_dim: int, num_heads: int,
attn_prob_dropout_ratio: float,
hidden_dropout_ratio: float,
pre_or_postLayerNorm: bool,
pg: ProcessGroup) -> int:
...

View File

@ -1,12 +0,0 @@
from torch import Tensor
def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor:
...
def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
...
def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int:
...

View File

@ -1,8 +0,0 @@
from torch import Tensor
def forward(input: Tensor, scale: float) -> Tensor:
...
def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
...

View File

@ -8,16 +8,28 @@ from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.kernel import fused_optim
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.logging import get_dist_logger
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier
from ._utils import has_inf_or_nan, zero_gard_by_list
from .grad_scaler import BaseGradScaler
try:
from colossalai._C import fused_optim
except:
fused_optim = None
__all__ = ['FP16Optimizer']
def load_fused_optim():
global fused_optim
if fused_optim is None:
fused_optim = FusedOptimBuilder().load()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
"""
adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
@ -30,6 +42,8 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
global fused_optim
load_fused_optim()
multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
else:
for this_, that_ in zip(this, that):

View File

@ -1,42 +1,7 @@
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
try:
from colossalai._C import fused_optim
except:
from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
try:
from colossalai._C import cpu_optim
except ImportError:
from colossalai.kernel.op_builder import CPUAdamBuilder
cpu_optim = CPUAdamBuilder().load()
try:
from colossalai._C import multihead_attention
except ImportError:
from colossalai.kernel.op_builder import MultiHeadAttnBuilder
multihead_attention = MultiHeadAttnBuilder().load()
try:
from colossalai._C import scaled_upper_triang_masked_softmax
except ImportError:
from colossalai.kernel.op_builder import ScaledSoftmaxBuilder
scaled_upper_triang_masked_softmax = ScaledSoftmaxBuilder().load()
try:
from colossalai._C import moe
except ImportError:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
__all__ = [
"fused_optim",
"cpu_optim",
"multihead_attention",
"moe",
"LayerNorm",
"FusedScaleMaskSoftmax",
"MultiHeadAttention",
"scaled_upper_triang_masked_softmax",
]

View File

@ -135,7 +135,8 @@ class MultiHeadAttention(nn.Module):
# Load cuda modules if needed
global colossal_multihead_attention
if colossal_multihead_attention is None:
from colossalai.kernel import multihead_attention
from colossalai.kernel.op_builder import MultiHeadAttnBuilder
multihead_attention = MultiHeadAttnBuilder().load()
colossal_multihead_attention = multihead_attention
# create the layer in cuda kernels.

View File

@ -6,13 +6,32 @@ from torch import Tensor
from torch.distributed import ProcessGroup
COL_MOE_KERNEL_FLAG = False
from colossalai.kernel import moe
try:
from colossalai._C import moe
except:
moe = None
def build_moe_if_not_prebuilt():
# load moe kernel during runtime if not pre-built
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
if ctx is not None:
ctx.comm_grp = group
@ -85,6 +104,9 @@ class MoeDispatch(torch.autograd.Function):
s = tokens.size(0)
h = tokens.size(1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
ctx.save_for_backward(mask, dest_idx)
@ -112,6 +134,9 @@ class MoeCombine(torch.autograd.Function):
c = ec // e
h = expert_tokens.size(-1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
@ -143,6 +168,8 @@ def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
return moe.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1

View File

@ -3,6 +3,7 @@ from typing import Optional
import torch
from colossalai.kernel.op_builder import CPUAdamBuilder
from colossalai.registry import OPTIMIZERS
from .nvme_optimizer import NVMeOptimizer
@ -76,12 +77,8 @@ class CPUAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
try:
import colossalai._C.cpu_optim
except ImportError:
raise ImportError('Please install colossalai from source code to use CPUAdam')
self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
adamw_mode)
cpu_adam = CPUAdamBuilder().load()
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
def torch_adam_update(self,
data,

View File

@ -65,7 +65,8 @@ class FusedAdam(torch.optim.Optimizer):
self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
from colossalai.kernel import fused_optim
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])

View File

@ -76,7 +76,8 @@ class FusedLAMB(torch.optim.Optimizer):
max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
from colossalai.kernel import fused_optim
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
# Skip buffer

View File

@ -80,7 +80,8 @@ class FusedSGD(Optimizer):
self.wd_after_momentum = wd_after_momentum
if multi_tensor_applier.available:
from colossalai.kernel import fused_optim
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0],

View File

@ -2,6 +2,7 @@ from typing import Any, Optional
import torch
from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
@ -77,7 +78,9 @@ class HybridAdam(NVMeOptimizer):
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
from colossalai.kernel import cpu_optim, fused_optim
# build during runtime if not found
cpu_optim = CPUAdamBuilder().load()
fused_optim = FusedOptimBuilder().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.gpu_adam_op = fused_optim.multi_tensor_adam

View File

@ -18,11 +18,15 @@ from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARA
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import fused_optim
from colossalai.tensor import ColoParameter, ProcessGroup
from .multi_tensor_apply import multi_tensor_applier
try:
from colossalai._C import fused_optim
except:
fused_optim = None
def print_rank_0(msg: str, logger=None):
"""Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
@ -123,6 +127,13 @@ def is_model_parallel_parameter(p):
def _calc_l2_norm(grads):
# we should not
global fused_optim
if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])

View File

@ -14,7 +14,6 @@ class MultiTensorApply(object):
def __init__(self, chunk_size):
try:
from colossalai.kernel import fused_optim
MultiTensorApply.available = True
self.chunk_size = chunk_size
except ImportError as err:

31
op_builder/README.md Normal file
View File

@ -0,0 +1,31 @@
# Build PyTorch Extensions
## Overview
Building PyTorch extensions can be a difficult task for users not from the system background. It is definitely frustrating if the users encounter many strange technical jargons when install Colossal-AI. Therefore, we will provide two methods of building the PyTorch extensions for the users.
1. Build CUDA extensions when running `pip install` if `CUDA_EXT=1`
2. Build the extension during runtime
The first method is more suitable for users who are familiar with CUDA environment configurations. The second method is for those who are not as they only need to build the kernel which is required by their program.
These two methods have different advantages and disadvantages.
Method 1 is good because it allows the user to build all kernels during installation and directly import the kernel. They don't need to care about kernel building when running their program. However, installation may fail if they don't know how to configure their environments and this leads to much frustration.
Method 2 is good because it allows the user to only build the kernel they actually need, such that there is a lower probability that they encounter environment issue. However, it may slow down their program due to the first build and subsequence load.
## PyTorch Extensions in Colossal-AI
As mentioned in the section above, our aim is to make these two methods coherently supported in Colossal-AI, meaning that for a kernel should be either built in `setup.py` or during runtime.
There are mainly two functions used to build extensions.
1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`.
2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime
Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong).
We have implemented the following conventions:
1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C`
2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete)
When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered.

View File

@ -1,7 +1,23 @@
from .cpu_adam import CPUAdamBuilder
from .fused_optim import FusedOptimBuilder
from .layernorm import LayerNormBuilder
from .moe import MOEBuilder
from .multi_head_attn import MultiHeadAttnBuilder
from .scaled_upper_triang_masked_softmax import ScaledSoftmaxBuilder
from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder
__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder', 'MOEBuilder']
ALL_OPS = {
'cpu_adam': CPUAdamBuilder,
'fused_optim': FusedOptimBuilder,
'moe': MOEBuilder,
'multi_head_attn': MultiHeadAttnBuilder,
'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder,
'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder,
'layernorm': LayerNormBuilder,
}
__all__ = [
'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder',
'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder',
'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder'
]

View File

@ -1,40 +1,49 @@
import importlib
import os
import re
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
def get_cuda_cc_flag() -> List:
"""get_cuda_cc_flag
class Builder(ABC):
"""
Builder is the base class to build extensions for PyTorch.
cc flag for your GPU arch
Args:
name (str): the name of the kernel to be built
prebuilt_import_path (str): the path where the extension is installed during pip install
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
cc_flag = []
for arch in torch.cuda.get_arch_list():
res = re.search(r'sm_(\d+)', arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60:
cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
def __init__(self, name: str, prebuilt_import_path: str):
self.name = name
self.prebuilt_import_path = prebuilt_import_path
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
return cc_flag
assert prebuilt_import_path.startswith('colossalai._C'), \
f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}'
def relative_to_abs_path(self, code_path: str) -> str:
"""
This function takes in a path relative to the colossalai root directory and return the absolute path.
"""
op_builder_module_path = Path(__file__).parent
class Builder(object):
def colossalai_src_path(self, code_path):
current_file_path = Path(__file__)
if os.path.islink(current_file_path.parent):
# symbolic link
return os.path.join(current_file_path.parent.parent.absolute(), code_path)
# if we install from source
# the current file path will be op_builder/builder.py
# if we install via pip install colossalai
# the current file path will be colossalai/kernel/op_builder/builder.py
# this is because that the op_builder inside colossalai is a symlink
# this symlink will be replaced with actual files if we install via pypi
# thus we cannot tell the colossalai root directory by checking whether the op_builder
# is a symlink, we can only tell whether it is inside or outside colossalai
if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'):
root_path = op_builder_module_path.parent.parent
else:
return os.path.join(current_file_path.parent.parent.absolute(), "colossalai", "kernel", code_path)
root_path = op_builder_module_path.parent.joinpath('colossalai')
code_abs_path = root_path.joinpath(code_path)
return str(code_abs_path)
def get_cuda_home_include(self):
"""
@ -46,47 +55,94 @@ class Builder(object):
cuda_include = os.path.join(CUDA_HOME, "include")
return cuda_include
def csrc_abs_path(self, path):
return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path)
# functions must be overrided begin
def sources_files(self):
@abstractmethod
def sources_files(self) -> List[str]:
"""
This function should return a list of source files for extensions.
"""
raise NotImplementedError
def include_dirs(self):
raise NotImplementedError
@abstractmethod
def include_dirs(self) -> List[str]:
"""
This function should return a list of inlcude files for extensions.
"""
pass
def cxx_flags(self):
raise NotImplementedError
@abstractmethod
def cxx_flags(self) -> List[str]:
"""
This function should return a list of cxx compilation flags for extensions.
"""
pass
def nvcc_flags(self):
raise NotImplementedError
@abstractmethod
def nvcc_flags(self) -> List[str]:
"""
This function should return a list of nvcc compilation flags for extensions.
"""
pass
# functions must be overrided over
def strip_empty_entries(self, args):
'''
Drop any empty strings from the list of compile and link flags
'''
return [x for x in args if len(x) > 0]
def import_op(self):
"""
This function will import the op module by its string name.
"""
return importlib.import_module(self.prebuilt_import_path)
def load(self, verbose=True):
"""
load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel.
If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
kernel is built during pip install, it can be accessed through `colossalai._C`.
load and compile cpu_adam lib at runtime
Warning: do not load this kernel repeatedly during model execution as it could slow down the training process.
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
import time
from torch.utils.cpp_extension import load
start_build = time.time()
op_module = load(name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_cuda_cflags=self.nvcc_flags(),
extra_ldflags=[],
verbose=verbose)
try:
op_module = self.import_op()
if verbose:
print(f"OP {self.prebuilt_import_path} already exists, skip building.")
except ImportError:
# construct the build directory
import torch
torch_version_major = torch.__version__.split('.')[0]
torch_version_minor = torch.__version__.split('.')[1]
torch_cuda_version = torch.version.cuda
home_directory = os.path.expanduser('~')
extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
build_directory = os.path.join(home_directory, extension_directory)
Path(build_directory).mkdir(parents=True, exist_ok=True)
if verbose:
print("=========================================================================================")
print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
print("=========================================================================================")
# load the kernel
op_module = load(name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_cuda_cflags=self.nvcc_flags(),
extra_ldflags=[],
build_directory=build_directory,
verbose=verbose)
build_duration = time.time() - start_build
if verbose:
@ -94,17 +150,16 @@ class Builder(object):
return op_module
def builder(self, name) -> 'CUDAExtension':
def builder(self) -> 'CUDAExtension':
"""
get a CUDAExtension instance used for setup.py
"""
from torch.utils.cpp_extension import CUDAExtension
return CUDAExtension(
name=name,
sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources_files()],
include_dirs=self.include_dirs(),
extra_compile_args={
'cxx': self.cxx_flags(),
'nvcc': self.nvcc_flags()
})
return CUDAExtension(name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args={
'cxx': self.strip_empty_entries(self.cxx_flags()),
'nvcc': self.strip_empty_entries(self.nvcc_flags())
})

View File

@ -6,24 +6,22 @@ from .utils import append_nvcc_threads
class CPUAdamBuilder(Builder):
NAME = "cpu_adam"
BASE_DIR = "cuda_native"
PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"
def __init__(self):
self.name = CPUAdamBuilder.NAME
super().__init__()
super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
# necessary 4 functions
def sources_files(self):
ret = [
os.path.join(CPUAdamBuilder.BASE_DIR, "csrc/cpu_adam.cpp"),
self.csrc_abs_path('cpu_adam.cpp'),
]
return [self.colossalai_src_path(path) for path in ret]
return ret
def include_dirs(self):
return [
self.colossalai_src_path(os.path.join(CPUAdamBuilder.BASE_DIR, "includes")),
self.csrc_abs_path("includes"),
self.get_cuda_home_include()
]
@ -36,7 +34,5 @@ class CPUAdamBuilder(Builder):
'-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
]
return append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags)
# necessary 4 functions
ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@ -1,20 +1,19 @@
import os
from .builder import Builder, get_cuda_cc_flag
from .builder import Builder
from .utils import get_cuda_cc_flag
class FusedOptimBuilder(Builder):
NAME = 'fused_optim'
BASE_DIR = "cuda_native/csrc"
NAME = "fused_optim"
PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim"
def __init__(self):
self.name = FusedOptimBuilder.NAME
super().__init__()
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH)
def sources_files(self):
ret = [
self.colossalai_src_path(os.path.join(FusedOptimBuilder.BASE_DIR, fname)) for fname in [
self.csrc_abs_path(fname) for fname in [
'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
]
@ -22,12 +21,12 @@ class FusedOptimBuilder(Builder):
return ret
def include_dirs(self):
ret = [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), self.get_cuda_home_include()]
return [self.colossalai_src_path(path) for path in ret]
ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()]
return ret
def cxx_flags(self):
extra_cxx_flags = []
return ['-O3'] + self.version_dependent_macros + extra_cxx_flags
version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
return ['-O3'] + version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ['-lineinfo']

29
op_builder/layernorm.py Normal file
View File

@ -0,0 +1,29 @@
import os
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag
class LayerNormBuilder(Builder):
NAME = "layernorm"
PREBUILT_IMPORT_PATH = "colossalai._C.layernorm"
def __init__(self):
super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH)
def sources_files(self):
ret = [self.csrc_abs_path(fname) for fname in ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu']]
return ret
def include_dirs(self):
ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()]
return ret
def cxx_flags(self):
return ['-O3'] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ['-maxrregcount=50']
extra_cuda_flags.extend(get_cuda_cc_flag())
ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + self.version_dependent_macros
return append_nvcc_threads(ret)

View File

@ -1,27 +1,30 @@
import os
from .builder import Builder, get_cuda_cc_flag
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag
class MOEBuilder(Builder):
NAME = "moe"
PREBUILT_IMPORT_PATH = "colossalai._C.moe"
def __init__(self):
self.base_dir = "cuda_native/csrc"
self.name = 'moe'
super().__init__()
super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH)
def include_dirs(self):
ret = []
ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
ret.append(os.path.join(self.base_dir, "kernels", "include"))
return [self.colossalai_src_path(path) for path in ret]
ret = [
self.csrc_abs_path("kernels/include"),
self.get_cuda_home_include()
]
return ret
def sources_files(self):
ret = [os.path.join(self.base_dir, fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']]
return [self.colossalai_src_path(path) for path in ret]
ret = [self.csrc_abs_path(fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']]
return ret
def cxx_flags(self):
return ['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
return ['-O3'] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = [
@ -30,4 +33,4 @@ class MOEBuilder(Builder):
]
extra_cuda_flags.extend(get_cuda_cc_flag())
ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
return ret
return append_nvcc_threads(ret)

View File

@ -1,32 +1,32 @@
import os
from .builder import Builder, get_cuda_cc_flag
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag
class MultiHeadAttnBuilder(Builder):
def __init__(self):
self.base_dir = "cuda_native/csrc"
self.name = 'multihead_attention'
super().__init__()
NAME = "multihead_attention"
PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention"
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
def __init__(self):
super().__init__(name=MultiHeadAttnBuilder.NAME,
prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH)
def include_dirs(self):
ret = []
ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
ret.append(os.path.join(self.base_dir, "kernels", "include"))
return [self.colossalai_src_path(path) for path in ret]
ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()]
return ret
def sources_files(self):
ret = [
os.path.join(self.base_dir, fname) for fname in [
self.csrc_abs_path(fname) for fname in [
'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu',
'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu',
'kernels/general_kernels.cu', 'kernels/cuda_util.cu'
]
]
return [self.colossalai_src_path(path) for path in ret]
return ret
def cxx_flags(self):
return ['-O3'] + self.version_dependent_macros
@ -37,5 +37,5 @@ class MultiHeadAttnBuilder(Builder):
'-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
]
extra_cuda_flags.extend(get_cuda_cc_flag())
ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
return ret
ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@ -0,0 +1,37 @@
import os
from .builder import Builder
from .utils import append_nvcc_threads
class ScaledMaskedSoftmaxBuilder(Builder):
NAME = "scaled_masked_softmax"
PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax"
def __init__(self):
super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH)
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path(fname) for fname in
['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu']
]
return ret
def include_dirs(self):
return [
self.csrc_abs_path("kernels/include"),
self.get_cuda_home_include()
]
def cxx_flags(self):
return ['-O3'] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = [
'-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
]
ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@ -1,36 +0,0 @@
import os
from .builder import Builder, get_cuda_cc_flag
class ScaledSoftmaxBuilder(Builder):
def __init__(self):
self.base_dir = "cuda_native/csrc"
self.name = 'scaled_upper_triang_masked_softmax'
super().__init__()
def include_dirs(self):
ret = []
ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
ret.append(os.path.join(self.base_dir, "kernels", "include"))
return [self.colossalai_src_path(path) for path in ret]
def sources_files(self):
ret = [
os.path.join(self.base_dir, fname)
for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu']
]
return [self.colossalai_src_path(path) for path in ret]
def cxx_flags(self):
return ['-O3']
def nvcc_flags(self):
extra_cuda_flags = [
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
'--expt-extended-lambda'
]
extra_cuda_flags.extend(get_cuda_cc_flag())
ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
return ret

View File

@ -0,0 +1,37 @@
import os
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag
class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder):
NAME = "scaled_upper_triangle_masked_softmax"
PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax"
def __init__(self):
super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH)
def include_dirs(self):
return [
self.csrc_abs_path("kernels/include"),
self.get_cuda_home_include()
]
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu']
]
return ret
def cxx_flags(self):
return ['-O3'] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = [
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
'--expt-extended-lambda'
]
extra_cuda_flags.extend(get_cuda_cc_flag())
ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@ -1,4 +1,6 @@
import re
import subprocess
from typing import List
def get_cuda_bare_metal_version(cuda_dir):
@ -11,6 +13,26 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor
def get_cuda_cc_flag() -> List:
"""get_cuda_cc_flag
cc flag for your GPU arch
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
cc_flag = []
for arch in torch.cuda.get_arch_list():
res = re.search(r'sm_(\d+)', arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60:
cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
return cc_flag
def append_nvcc_threads(nvcc_extra_args):
from torch.utils.cpp_extension import CUDA_HOME

View File

@ -133,59 +133,11 @@ if build_cuda_ext:
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]):
return CUDAExtension(
name=name,
sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources],
include_dirs=[os.path.join(this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')],
extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros + extra_cxx_flags,
'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)
})
#### fused optim kernels ###
from op_builder import FusedOptimBuilder
ext_modules.append(FusedOptimBuilder().builder('colossalai._C.fused_optim'))
#### N-D parallel kernels ###
cc_flag = []
for arch in torch.cuda.get_arch_list():
res = re.search(r'sm_(\d+)', arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60:
cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
extra_cuda_flags = [
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
'--expt-extended-lambda'
]
from op_builder import ScaledSoftmaxBuilder
ext_modules.append(ScaledSoftmaxBuilder().builder('colossalai._C.scaled_upper_triang_masked_softmax'))
ext_modules.append(
cuda_ext_helper('colossalai._C.scaled_masked_softmax',
['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag))
from op_builder import MOEBuilder
ext_modules.append(MOEBuilder().builder('colossalai._C.moe'))
extra_cuda_flags = ['-maxrregcount=50']
ext_modules.append(
cuda_ext_helper('colossalai._C.layer_norm', ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
extra_cuda_flags + cc_flag))
### MultiHeadAttn Kernel ####
from op_builder import MultiHeadAttnBuilder
ext_modules.append(MultiHeadAttnBuilder().builder('colossalai._C.multihead_attention'))
### Gemini Adam kernel ####
from op_builder import CPUAdamBuilder
ext_modules.append(CPUAdamBuilder().builder('colossalai._C.cpu_optim'))
from op_builder import ALL_OPS
for name, builder_cls in ALL_OPS.items():
print(f'===== Building Extension {name} =====')
ext_modules.append(builder_cls().builder())
setup(name='colossalai',
version=get_version(),
@ -227,4 +179,4 @@ setup(name='colossalai',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: System :: Distributed Computing',
],
package_data={'colossalai': ['_C/*.pyi']})
package_data={'colossalai': ['_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', 'kernel/cuda_native/csrc/kernels/include/*']})

View File

@ -66,7 +66,8 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
exp_avg_sq = torch.rand(p_data.shape)
exp_avg_sq_copy = exp_avg_sq.clone()
from colossalai.kernel import cpu_optim
from colossalai.kernel.op_builder import CPUAdamBuilder
cpu_optim = CPUAdamBuilder().load()
cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)

View File

@ -46,7 +46,8 @@ def torch_adam_update(
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
from colossalai.kernel import fused_optim
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
fused_adam = fused_optim.multi_tensor_adam
dummy_overflow_buf = torch.cuda.IntTensor([0])