diff --git a/colossalai/kernel/op_builder/builder.py b/colossalai/kernel/op_builder/builder.py index 7d1147f97..3c64c3d59 100644 --- a/colossalai/kernel/op_builder/builder.py +++ b/colossalai/kernel/op_builder/builder.py @@ -30,13 +30,31 @@ class Builder(object): else: return os.path.join(Path(__file__).parent.parent.absolute(), code_path) - def get_cuda_include(self): + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ from torch.utils.cpp_extension import CUDA_HOME if CUDA_HOME is None: raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") cuda_include = os.path.join(CUDA_HOME, "include") return cuda_include + # functions that must be overridden: begin + def sources_files(self): + raise NotImplementedError + + def include_dirs(self): + raise NotImplementedError + + def cxx_flags(self): + raise NotImplementedError + + def nvcc_flags(self): + raise NotImplementedError + + # functions that must be overridden: end + def strip_empty_entries(self, args): ''' Drop any empty strings from the list of compile and link flags @@ -57,10 +75,10 @@ class Builder(object): start_build = time.time() op_module = load(name=self.name, - sources=self.strip_empty_entries(self.sources), - extra_include_paths=self.strip_empty_entries(self.extra_include_paths), - extra_cflags=self.extra_cxx_flags, - extra_cuda_cflags=self.extra_cuda_flags, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), extra_ldflags=[], verbose=verbose) @@ -69,3 +87,18 @@ class Builder(object): print(f"Time to load {self.name} op: {build_duration} seconds") return op_module + + def builder(self, name) -> 'CUDAExtension': + """ + get a CUDAExtension instance used for setup.py + """ + from torch.utils.cpp_extension import CUDAExtension + + return CUDAExtension( + name=name, 
sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources_files()], + include_dirs=self.include_dirs(), + extra_compile_args={ + 'cxx': self.cxx_flags(), + 'nvcc': self.nvcc_flags() + }) diff --git a/colossalai/kernel/op_builder/cpu_adam.py b/colossalai/kernel/op_builder/cpu_adam.py index 1fb5adfd6..7b5b46319 100644 --- a/colossalai/kernel/op_builder/cpu_adam.py +++ b/colossalai/kernel/op_builder/cpu_adam.py @@ -12,68 +12,31 @@ class CPUAdamBuilder(Builder): self.name = CPUAdamBuilder.NAME super().__init__() - self.sources = [self.colossalai_src_path(path) for path in self.sources_files()] - self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()] - self.extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] - self.extra_cuda_flags = [ + self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + + # necessary 4 functions + def sources_files(self): + ret = [ + os.path.join(CPUAdamBuilder.BASE_DIR, "csrc/cpu_adam.cpp"), + ] + return [self.colossalai_src_path(path) for path in ret] + + def include_dirs(self): + return [ + self.colossalai_src_path(os.path.join(CPUAdamBuilder.BASE_DIR, "includes")), + self.get_cuda_home_include() + ] + + def cxx_flags(self): + extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] + return ['-O3'] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + extra_cuda_flags = [ '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' ] - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - def sources_files(self): - return [ - os.path.join(CPUAdamBuilder.BASE_DIR, "csrc/cpu_adam.cpp"), - ] + return append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros + 
extra_cuda_flags) - def include_paths(self): - return [os.path.join(CPUAdamBuilder.BASE_DIR, "includes"), self.get_cuda_include()] - - def strip_empty_entries(self, args): - ''' - Drop any empty strings from the list of compile and link flags - ''' - return [x for x in args if len(x) > 0] - - def builder(self, name) -> 'CUDAExtension': - """ - get a CUDAExtension instance used for setup.py - """ - from torch.utils.cpp_extension import CUDAExtension - - return CUDAExtension( - name=name, - sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources], - include_dirs=self.extra_include_paths, - extra_compile_args={ - 'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags, - 'nvcc': - append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros + - self.extra_cuda_flags) - }) - - def load(self, verbose=True): - """ - load and compile cpu_adam lib at runtime - - Args: - verbose (bool, optional): show detailed info. Defaults to True. 
- """ - import time - - from torch.utils.cpp_extension import load - start_build = time.time() - - op_module = load(name=self.name, - sources=self.strip_empty_entries(self.sources), - extra_include_paths=self.strip_empty_entries(self.extra_include_paths), - extra_cflags=self.extra_cxx_flags, - extra_cuda_cflags=self.extra_cuda_flags, - extra_ldflags=[], - verbose=verbose) - - build_duration = time.time() - start_build - if verbose: - print(f"Time to load {self.name} op: {build_duration} seconds") - - return op_module + # necessary 4 functions diff --git a/colossalai/kernel/op_builder/fused_optim.py b/colossalai/kernel/op_builder/fused_optim.py index 8bfcf3471..1f1bb9e11 100644 --- a/colossalai/kernel/op_builder/fused_optim.py +++ b/colossalai/kernel/op_builder/fused_optim.py @@ -1,7 +1,4 @@ import os -import re - -import torch from .builder import Builder, get_cuda_cc_flag @@ -13,33 +10,26 @@ class FusedOptimBuilder(Builder): def __init__(self): self.name = FusedOptimBuilder.NAME super().__init__() - - self.extra_cxx_flags = [] - self.extra_cuda_flags = ['-lineinfo'] - self.extra_cuda_flags.extend(get_cuda_cc_flag()) - - self.sources = [self.colossalai_src_path(path) for path in self.sources_files()] - self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()] self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] def sources_files(self): - return [ - os.path.join(FusedOptimBuilder.BASE_DIR, fname) for fname in [ + ret = [ + self.colossalai_src_path(os.path.join(FusedOptimBuilder.BASE_DIR, fname)) for fname in [ 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' ] ] + return ret - def include_paths(self): - return [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), self.get_cuda_include()] + def include_dirs(self): + ret = [os.path.join(FusedOptimBuilder.BASE_DIR, 
"includes"), self.get_cuda_home_include()] + return [self.colossalai_src_path(path) for path in ret] - def builder(self, name): - from torch.utils.cpp_extension import CUDAExtension - return CUDAExtension( - name=name, - sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources], - include_dirs=self.extra_include_paths, - extra_compile_args={ - 'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags, - 'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags - }) + def cxx_flags(self): + extra_cxx_flags = [] + return ['-O3'] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + extra_cuda_flags = ['-lineinfo'] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ['-O3', '--use_fast_math'] + extra_cuda_flags diff --git a/colossalai/kernel/op_builder/multi_head_attn.py b/colossalai/kernel/op_builder/multi_head_attn.py index b83b193a6..f6eaf6c3d 100644 --- a/colossalai/kernel/op_builder/multi_head_attn.py +++ b/colossalai/kernel/op_builder/multi_head_attn.py @@ -9,41 +9,33 @@ class MultiHeadAttnBuilder(Builder): self.base_dir = "cuda_native/csrc" self.name = 'multihead_attention' super().__init__() - self.extra_cxx_flags = [] - self.extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' - ] - - self.extra_cuda_flags.extend(get_cuda_cc_flag()) - self.sources = [self.colossalai_src_path(path) for path in self.sources_files()] - self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()] self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + def include_dirs(self): + ret = [] + ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()] + ret.append(os.path.join(self.base_dir, "kernels", "include")) + return [self.colossalai_src_path(path) for path in ret] + def sources_files(self): - return 
[ + ret = [ os.path.join(self.base_dir, fname) for fname in [ 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' ] ] + return [self.colossalai_src_path(path) for path in ret] - def include_paths(self): - ret = [] - ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_include()] - ret.append(os.path.join(self.base_dir, "kernels", "include")) - print("include_paths", ret) + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ['-O3', '--use_fast_math'] + extra_cuda_flags return ret - - def builder(self, name): - from torch.utils.cpp_extension import CUDAExtension - return CUDAExtension( - name=name, - sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources], - include_dirs=self.extra_include_paths, - extra_compile_args={ - 'cxx': ['-O3'] + self.version_dependent_macros, - 'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags - })