diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 8933fc0a3..5356fbf48 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,7 +1,14 @@ +from .cpu_adam_loader import CPUAdamLoader from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention +from .extensions.flash_attention import AttnMaskType +from .flash_attention_loader import ColoAttention, FlashAttentionLoader __all__ = [ "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention", + "CPUAdamLoader", + "FlashAttentionLoader", + "ColoAttention", + "AttnMaskType", ] diff --git a/colossalai/kernel/base_kernel_loader.py b/colossalai/kernel/base_kernel_loader.py new file mode 100644 index 000000000..ff7a43261 --- /dev/null +++ b/colossalai/kernel/base_kernel_loader.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod +from typing import Dict, List + +from .extensions.base_extension import BaseExtension + + +class BaseKernelLoader(ABC): + """ + Usage: + kernel_loader = KernelLoader() + kernel = kernel_loader.load() + """ + + def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]): + self._extension_map = extension_map + self._supported_device = supported_device + + def run_checks(self): + # run supported device check and other possible checks + pass + + @abstractmethod + def fetch_kernel(self): + pass + + def load(self): + self.run_checks() + return self.fetch_kernel() diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py new file mode 100644 index 000000000..0df6bd49b --- /dev/null +++ b/colossalai/kernel/cpu_adam_loader.py @@ -0,0 +1,64 @@ +import platform +from collections import OrderedDict + +from .base_kernel_loader import BaseKernelLoader +from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension + + +class CPUAdamLoader(BaseKernelLoader): + """ + CPU Adam Loader + + Usage: + # init + cpu_adam = CPUAdamLoader().load() + cpu_adam_op = cpu_adam.CPUAdamOptimizer( + alpha, beta1, beta2, epsilon, weight_decay, adamw_mode, + ) + ... + # optim step + cpu_adam_op.step( + step, lr, beta1, beta2, epsilon, weight_decay, bias_correction, + params, grads, exp_avg, exp_avg_sq, loss_scale, + ) + + Args: + func CPUAdamOptimizer: + alpha (float): learning rate. Default to 1e-3. + beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9. + beta2 (float): coefficients used for computing running averages of its square. Default to 0.99. + epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8. + weight_decay (float): weight decay (L2 penalty). Default to 0. + adamw_mode (bool): whether to use the adamw. Default to True. + func step: + step (int): current step. + lr (float): learning rate. + beta1 (float): coefficients used for computing running averages of gradient. + beta2 (float): coefficients used for computing running averages of its square. + epsilon (float): term added to the denominator to improve numerical stability. + weight_decay (float): weight decay (L2 penalty). + bias_correction (bool): whether to use bias correction. + params (torch.Tensor): parameter. + grads (torch.Tensor): gradient. + exp_avg (torch.Tensor): exp average. + exp_avg_sq (torch.Tensor): exp average square. + loss_scale (float): loss scale value. 
+ """ + + def __init__(self): + super().__init__( + extension_map=OrderedDict( + arm=ArmCPUAdamExtension, + x86=X86CPUAdamExtension, + ), + supported_device=["cpu"], + ) + + def fetch_kernel(self): + if platform.machine() == "x86_64": + kernel = self._extension_map["x86"]().fetch() + elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]: + kernel = self._extension_map["arm"]().fetch() + else: + raise Exception("not supported") + return kernel diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index f8a974b5f..0eac28d23 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -1,5 +1,4 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm -from .mha.mha import ColoAttention from .multihead_attention import MultiHeadAttention from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax @@ -8,6 +7,5 @@ __all__ = [ "MultiHeadAttention", "FusedScaleMaskSoftmax", "ScaledUpperTriangMaskedSoftmax", - "ColoAttention", "AttnMaskType", ] diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py deleted file mode 100644 index cad36e598..000000000 --- a/colossalai/kernel/cuda_native/mha/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import ColoAttention - -__all__ = ["ColoAttention"] diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py deleted file mode 100644 index b56d37cf0..000000000 --- a/colossalai/kernel/cuda_native/mha/mha.py +++ /dev/null @@ -1,114 +0,0 @@ -import math -from typing import Optional - -import torch -from einops import rearrange - -from ..scaled_softmax import AttnMaskType -from .flash_attn_2 import HAS_FLASH_ATTN -from .mem_eff_attn import HAS_MEM_EFF_ATTN -from .utils import Repad, SeqLenInfo, Unpad - -if HAS_FLASH_ATTN: - from .flash_attn_2 import flash_attention -if HAS_MEM_EFF_ATTN: - from .mem_eff_attn import mem_eff_attention - - -class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): - super().__init__() - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." 
- if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: - raise Exception("flash attention can not support!") - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - origin_attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None, - ): - attn = None - if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: - attn = flash_attention - else: - attn = mem_eff_attention - - padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - # unpad - seq_len_info_q = None - seq_len_info_kv = None - if padded: - # bert style, unpad process - assert ( - attn_mask is not None - ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, ( - "attention mask is supposed to have shape (batch_size, seq_len), " - + f"but got {attn_mask.dim()} dimensions." - ) - - # bert style - if tgt_len == src_len: - seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query, key, value = self.unpad( - torch.stack([query, key, value], dim=2), seq_len_info_q.indices - ).unbind(dim=1) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - seq_len_info_kv = seq_len_info_q - else: - seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) - seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query = rearrange(query, "b s ... 
-> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( - dim=1 - ) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - - out = attn( - query, - key, - value, - seq_len_info_q, - seq_len_info_kv, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded, - ) - - # repad - if padded: - if batch_size > 1: - out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - - out = rearrange(out, "b s h d -> b s (h d)") - return out diff --git a/colossalai/kernel/extensions/__init__.py b/colossalai/kernel/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/kernel/extensions/base_extension.py b/colossalai/kernel/extensions/base_extension.py new file mode 100644 index 000000000..8905dbf13 --- /dev/null +++ b/colossalai/kernel/extensions/base_extension.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod +from typing import Callable + + +class BaseExtension(ABC): + @abstractmethod + def requires_build(self) -> bool: + pass + + @abstractmethod + def build(self) -> None: + pass + + @abstractmethod + def load(self) -> Callable: + pass + + def fetch(self) -> Callable: + if self.requires_build: + self.build() + return self.load() diff --git a/colossalai/kernel/extensions/cpu_adam/__init__.py b/colossalai/kernel/extensions/cpu_adam/__init__.py new file mode 100644 index 000000000..b14f3a978 --- /dev/null +++ b/colossalai/kernel/extensions/cpu_adam/__init__.py @@ -0,0 +1,4 @@ +from .arm_extension import ArmCPUAdamExtension +from .x86_extension import X86CPUAdamExtension + +__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"] diff --git a/colossalai/kernel/extensions/cpu_adam/arm_extension.py b/colossalai/kernel/extensions/cpu_adam/arm_extension.py new file mode 100644 index 000000000..9868059bf --- /dev/null +++ b/colossalai/kernel/extensions/cpu_adam/arm_extension.py @@ -0,0 +1,53 @@ +from ..base_extension import BaseExtension +from ..extension_builder import ExtensionBuilder + + +class ArmCPUAdamExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + self.kernel_builder = ArmCPUAdamBuilder() + self._requires_build = False + + @property + def requires_build(self) -> bool: + return self._requires_build + + def build(self): + self.kernel_builder.build() + self._requires_build = True + + def load(self): + return self.kernel_builder.load() + + +class ArmCPUAdamBuilder(ExtensionBuilder): + NAME = "arm_cpu_adam" + PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" + ext_type = "cpu" + + def __init__(self): + super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("cpu_adam_arm.cpp"), + ] + return ret + + def include_dirs(self): + return [self.csrc_abs_path("includes")] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-g", + "-Wno-reorder", + "-fopenmp", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + return [] diff --git a/colossalai/kernel/extensions/cpu_adam/x86_extension.py b/colossalai/kernel/extensions/cpu_adam/x86_extension.py new file mode 100644 index 000000000..687c91f35 --- /dev/null +++ 
b/colossalai/kernel/extensions/cpu_adam/x86_extension.py @@ -0,0 +1,65 @@ +from ..base_extension import BaseExtension +from ..extension_builder import ExtensionBuilder +from ..utils import append_nvcc_threads + + +class X86CPUAdamExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + self.kernel_builder = X86CPUAdamBuilder() + self._requires_build = False + + @property + def requires_build(self) -> bool: + return self._requires_build + + def build(self): + self.kernel_builder.build() + self._requires_build = True + + def load(self): + return self.kernel_builder.load() + + +class X86CPUAdamBuilder(ExtensionBuilder): + NAME = "cpu_adam" + PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" + + def __init__(self): + super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH) + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("cpu_adam.cpp"), + ] + return ret + + def include_dirs(self): + return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-lcudart", + "-lcublas", + "-g", + "-Wno-reorder", + "-fopenmp", + "-march=native", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + extra_cuda_flags = [ + "-std=c++14", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/colossalai/kernel/extensions/extension_builder.py b/colossalai/kernel/extensions/extension_builder.py new file mode 100644 index 000000000..5849fcfa6 --- /dev/null +++ b/colossalai/kernel/extensions/extension_builder.py @@ -0,0 +1,243 @@ +# This code has been adapted from the DeepSpeed library. +# Copyright (c) Microsoft Corporation. + +# Licensed under the MIT License. +import importlib +import os +import time +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Optional, Union + +from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 + + +class ExtensionBuilder(ABC): + """ + Builder is the base class to build extensions for PyTorch. + + Args: + name (str): the name of the kernel to be built + prebuilt_import_path (str): the path where the extension is installed during pip install + """ + + ext_type: str = "cuda" + + def __init__(self, name: str, prebuilt_import_path: str): + self.name = name + self.prebuilt_import_path = prebuilt_import_path + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + # we store the op as an attribute to avoid repeated building and loading + self.cached_op_module = None + + assert prebuilt_import_path.startswith( + "colossalai._C" + ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. 
+ """ + op_builder_module_path = Path(__file__).parent + + # if we install from source + # the current file path will be op_builder/builder.py + # if we install via pip install colossalai + # the current file path will be colossalai/kernel/op_builder/builder.py + # this is because that the op_builder inside colossalai is a symlink + # this symlink will be replaced with actual files if we install via pypi + # thus we cannot tell the colossalai root directory by checking whether the op_builder + # is a symlink, we can only tell whether it is inside or outside colossalai + if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): + root_path = op_builder_module_path.parent.parent + elif str(op_builder_module_path).endswith("colossalai/kernel/extensions"): + root_path = op_builder_module_path.parent.parent + else: + root_path = op_builder_module_path.parent.joinpath("colossalai") + + code_abs_path = root_path.joinpath(code_path) + return str(code_abs_path) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + raise NotImplementedError + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + + # functions must be overrided over + def strip_empty_entries(self, args): + """ + Drop any empty strings from the list of compile and link flags + """ + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def check_runtime_build_environment(self): + """ + Check whether the system environment is ready for extension compilation. + """ + try: + from torch.utils.cpp_extension import CUDA_HOME + + TORCH_AVAILABLE = True + except ImportError: + TORCH_AVAILABLE = False + CUDA_HOME = None + + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" + ) + + if CUDA_HOME is None: + raise RuntimeError( + "CUDA_HOME is not found. 
You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" + ) + + # make sure CUDA is available for compilation during + cuda_available = check_cuda_availability() + if not cuda_available: + raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.") + + # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not + check_system_pytorch_cuda_match(CUDA_HOME) + + def build(self, verbose: Optional[bool] = None): + """ + If the kernel is not built during pip install, it will build the kernel. + If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the + kernel is built during pip install, it can be accessed through `colossalai._C`. + + Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. + + Args: + verbose (bool, optional): show detailed info. Defaults to True. + """ + if verbose is None: + verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" + try: + # if the kernel has been pre-built during installation + # we just directly import it + op_module = self.import_op() + if verbose: + print_rank_0( + f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." + ) + except ImportError: + # check environment + if self.ext_type == "cuda": + self.check_runtime_build_environment() + + # time the kernel compilation + start_build = time.time() + + # construct the build directory + import torch + from torch.utils.cpp_extension import load + + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] + torch_cuda_version = torch.version.cuda + home_directory = os.path.expanduser("~") + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" + build_directory = os.path.join(home_directory, extension_directory) + Path(build_directory).mkdir(parents=True, exist_ok=True) + + if verbose: + print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") + + # load the kernel + op_module = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=build_directory, + verbose=verbose, + ) + + build_duration = time.time() - start_build + + # log jit compilation time + if verbose: + print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds") + + # cache the built/loaded kernel + self.cached_op_module = op_module + + def load(self, verbose: Optional[bool] = None): + """ + load the kernel during runtime. + + Args: + verbose (bool, optional): show detailed info. Defaults to True. + """ + # if the kernel has be compiled and cached, we directly use it + assert self.cached_op_module is not None, "Please build the kernel first before loading it." 
+ return self.cached_op_module + + def builder(self) -> Union["CUDAExtension", "CppExtension"]: + """ + get a CUDAExtension instance used for setup.py + """ + from torch.utils.cpp_extension import CppExtension, CUDAExtension + + if self.ext_type == "cpp": + return CppExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args=self.strip_empty_entries(self.cxx_flags()), + ) + + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/colossalai/kernel/extensions/flash_attention/__init__.py b/colossalai/kernel/extensions/flash_attention/__init__.py new file mode 100644 index 000000000..79c6935d2 --- /dev/null +++ b/colossalai/kernel/extensions/flash_attention/__init__.py @@ -0,0 +1,19 @@ +from .cuda_flash_attn_2_extension import HAS_FLASH_ATTN, CudaFlashAttnExtension +from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension +from .npu_sdpa_attn_extension import NpuSdpaAttnExtension +from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension +from .utils import AttnMaskType, Repad, SeqLenInfo, Unpad + +__all__ = [ + "CudaFlashAttnExtension", + "CudaMemoryEfficentAttnExtension", + "NpuSdpaAttnExtension", + "NpuTriangleAttnExtension", + "HAS_FLASH_ATTN", + "HAS_MEM_EFF_ATTN", + "HAS_NPU_TRIANGLE_ATTENTION", + "Unpad", + "AttnMaskType", + "Repad", + "SeqLenInfo", +] diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py similarity index 67% rename from colossalai/kernel/cuda_native/mha/flash_attn_2.py rename to colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py index de2ccaa49..99c353606 100644 --- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py @@ -1,10 +1,14 @@ -import warnings from typing import Optional import torch +from ..base_extension import BaseExtension +from ..utils import print_rank_0 +from .utils import SeqLenInfo + def is_ampere_or_better_gpu(): + # Check Ampere GPUs or newer if torch.cuda.is_available(): device = torch.device("cuda") properties = torch.cuda.get_device_properties(device) @@ -13,31 +17,28 @@ def is_ampere_or_better_gpu(): return False -# "Check Ampere GPUs or newer" HAS_FLASH_ATTN = False +ERROR_MSG = None if is_ampere_or_better_gpu(): - HAS_FLASH_ATTN = True -else: - warnings.warn("FlashAttention only supports Ampere GPUs or newer.") - HAS_FLASH_ATTN = False -try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + + HAS_FLASH_ATTN = True + except ImportError: + ERROR_MSG = "ImportError: please install flash_attn from https://github.com/HazyResearch/flash-attention" +else: + ERROR_MSG = "ImportError: FlashAttention only supports Ampere GPUs or newer." 
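To make the BaseExtension / ExtensionBuilder contract above concrete, here is a hypothetical CPU-only extension sketch; the kernel name, source file, and import path are invented for illustration and do not exist in the repository:

from colossalai.kernel.extensions.base_extension import BaseExtension
from colossalai.kernel.extensions.extension_builder import ExtensionBuilder


class MyCpuKernelBuilder(ExtensionBuilder):
    ext_type = "cpu"  # skip the CUDA build-environment checks in build()

    def __init__(self):
        # both names are placeholders; prebuilt_import_path must start with "colossalai._C"
        super().__init__(name="my_cpu_kernel", prebuilt_import_path="colossalai._C.my_cpu_kernel")

    # the four methods every builder must provide
    def sources_files(self):
        return [self.csrc_abs_path("my_cpu_kernel.cpp")]  # hypothetical source file

    def include_dirs(self):
        return [self.csrc_abs_path("includes")]

    def cxx_flags(self):
        return ["-O3", "-fopenmp"]

    def nvcc_flags(self):
        return []


class MyCpuKernelExtension(BaseExtension):
    def __init__(self):
        self._built = False
        self.builder = MyCpuKernelBuilder()

    @property
    def requires_build(self):
        return not self._built

    def build(self):
        self.builder.build()  # imports the prebuilt module or JIT-compiles and caches it
        self._built = True

    def load(self):
        return self.builder.load()  # returns the cached op module


# fetch() calls build() only while requires_build is True, then returns load()
kernel = MyCpuKernelExtension().fetch()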
- HAS_FLASH_ATTN = True -except ImportError: - warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") - HAS_FLASH_ATTN = False if HAS_FLASH_ATTN: - from .utils import SeqLenInfo - def flash_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_len_info_q: SeqLenInfo, seq_len_info_kv: SeqLenInfo, + origin_attn_mask: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: float = None, @@ -77,3 +78,23 @@ if HAS_FLASH_ATTN: else: attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) return attn_out + + +class CudaFlashAttnExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + + @property + def requires_build(self): + return False + + def build(self): + pass + + def is_available(self): + if HAS_FLASH_ATTN == False: + print_rank_0(ERROR_MSG) + return HAS_FLASH_ATTN + + def load(self): + return flash_attention diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py similarity index 74% rename from colossalai/kernel/cuda_native/mha/mem_eff_attn.py rename to colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py index 649e74d61..4954ab5b1 100644 --- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py +++ b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py @@ -1,4 +1,10 @@ -import warnings +from typing import Optional + +import torch + +from ..base_extension import BaseExtension +from ..utils import print_rank_0 +from .utils import SeqLenInfo HAS_MEM_EFF_ATTN = False try: @@ -12,19 +18,13 @@ try: HAS_MEM_EFF_ATTN = True except ImportError: - warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") - HAS_MEM_EFF_ATTN = False + pass if HAS_MEM_EFF_ATTN: """ A general attention module using the flash attention kernels from xformers: https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha """ - from typing import Optional - - import torch - - from .utils import SeqLenInfo allow_alibi = True for op in MemoryEfficientAttentionCutlassOp: @@ -36,6 +36,7 @@ if HAS_MEM_EFF_ATTN: v: torch.Tensor, seq_len_info_q: SeqLenInfo, seq_len_info_kv: SeqLenInfo, + origin_attn_mask: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: float = None, @@ -68,3 +69,23 @@ if HAS_MEM_EFF_ATTN: out = out.squeeze(0) return out + + +class CudaMemoryEfficentAttnExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + + @property + def requires_build(self) -> bool: + return False + + def build(self): + pass + + def is_available(self): + if HAS_MEM_EFF_ATTN == False: + print_rank_0("ImportError: please install xformers from https://github.com/facebookresearch/xformers") + return HAS_MEM_EFF_ATTN + + def load(self): + return mem_eff_attention diff --git a/colossalai/kernel/npu/mha/sdpa_attn.py b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py similarity index 73% rename from colossalai/kernel/npu/mha/sdpa_attn.py rename to colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py index 2af1dbae2..7dc9d9b9b 100644 --- a/colossalai/kernel/npu/mha/sdpa_attn.py +++ b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py @@ -1,16 +1,20 @@ import torch from einops import rearrange +from ..base_extension import BaseExtension + def npu_sdpa_attention( q: 
torch.Tensor, k: torch.Tensor, v: torch.Tensor, - attn_mask: torch.Tensor = None, + seq_len_info_q=None, + seq_len_info_kv=None, origin_attn_mask: torch.Tensor = None, - scale: float = 1.0, dropout_p: float = 0.0, - is_causal: bool = True, + scale: float = 1.0, + causal=None, + padded=None, ): """ The scaled dot product attention. @@ -39,3 +43,18 @@ def npu_sdpa_attention( ) output = rearrange(output, "b h s d -> b s (h d)") return output + + +class NpuSdpaAttnExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + + @property + def requires_build(self) -> bool: + return False + + def build(self): + pass + + def load(self): + return npu_sdpa_attention diff --git a/colossalai/kernel/npu/mha/triangle_attn.py b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py similarity index 86% rename from colossalai/kernel/npu/mha/triangle_attn.py rename to colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py index 619076d5f..a760f56a1 100644 --- a/colossalai/kernel/npu/mha/triangle_attn.py +++ b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py @@ -13,18 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import torch from einops import rearrange +from ..base_extension import BaseExtension +from ..utils import print_rank_0 + HAS_NPU_TRIANGLE_ATTENTION = False try: from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax HAS_NPU_TRIANGLE_ATTENTION = True except ImportError: - logging.warning("Import torch_npu Error.") + pass if HAS_NPU_TRIANGLE_ATTENTION: @@ -33,11 +35,13 @@ if HAS_NPU_TRIANGLE_ATTENTION: q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - attn_mask: torch.Tensor = None, + seq_len_info_q=None, + seq_len_info_kv=None, origin_attn_mask: torch.Tensor = None, - scale: float = 1.0, dropout_p: float = 0.0, - is_causal: bool = True, + scale: float = 1.0, + causal=None, + padded=None, block_size=512, ): """ @@ -113,3 +117,25 @@ if HAS_NPU_TRIANGLE_ATTENTION: # Context layer. [b, sq, hp] # ========================= return context_layer + + +class NpuTriangleAttnExtension(BaseExtension): + def __init__(self) -> None: + super().__init__() + + @property + def requires_build(self) -> bool: + return False + + def build(self): + pass + + def is_available(self): + if HAS_NPU_TRIANGLE_ATTENTION == False: + print_rank_0( + "ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api." 
+ ) + return HAS_NPU_TRIANGLE_ATTENTION + + def load(self): + return npu_triangle_attention diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/extensions/flash_attention/utils.py similarity index 96% rename from colossalai/kernel/cuda_native/mha/utils.py rename to colossalai/kernel/extensions/flash_attention/utils.py index 5f01e3ef3..0eab9e89f 100644 --- a/colossalai/kernel/cuda_native/mha/utils.py +++ b/colossalai/kernel/extensions/flash_attention/utils.py @@ -1,3 +1,4 @@ +import enum from dataclasses import dataclass from typing import Iterable, Tuple @@ -80,3 +81,9 @@ class SeqLenInfo: max_seqlen = max(seqlens) cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 diff --git a/colossalai/kernel/extensions/utils.py b/colossalai/kernel/extensions/utils.py new file mode 100644 index 000000000..3f75f952d --- /dev/null +++ b/colossalai/kernel/extensions/utils.py @@ -0,0 +1,229 @@ +import os +import re +import subprocess +import warnings +from typing import List + + +def print_rank_0(message: str) -> None: + """ + Print on only one process to avoid spamming. + """ + try: + import torch.distributed as dist + + if not dist.is_initialized(): + is_main_rank = True + else: + is_main_rank = dist.get_rank() == 0 + except ImportError: + is_main_rank = True + + if is_main_rank: + print(message) + + +def get_cuda_version_in_pytorch() -> List[int]: + """ + This function returns the CUDA version in the PyTorch build. + + Returns: + The CUDA version required by PyTorch, in the form of tuple (major, minor). + """ + import torch + + try: + torch_cuda_major = torch.version.cuda.split(".")[0] + torch_cuda_minor = torch.version.cuda.split(".")[1] + except: + raise ValueError( + "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda" + ) + return torch_cuda_major, torch_cuda_minor + + +def get_cuda_bare_metal_version(cuda_dir) -> List[int]: + """ + Get the System CUDA version from nvcc. + + Args: + cuda_dir (str): the directory for CUDA Toolkit. + + Returns: + The CUDA version required by PyTorch, in the form of tuple (major, minor). + """ + nvcc_path = os.path.join(cuda_dir, "bin/nvcc") + + if cuda_dir is None: + raise ValueError( + f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly." + ) + + # check for nvcc path + if not os.path.exists(nvcc_path): + raise FileNotFoundError( + f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME." + ) + + # parse the nvcc -v output to obtain the system cuda version + try: + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + except: + raise ValueError( + f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. 
The output for 'nvcc -v' is \n{raw_output}" + ) + + return bare_metal_major, bare_metal_minor + + +def check_system_pytorch_cuda_match(cuda_dir): + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch() + + if bare_metal_major != torch_cuda_major: + raise Exception( + f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) " + f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})." + "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ." + ) + + if bare_metal_minor != torch_cuda_minor: + warnings.warn( + f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. " + "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. " + "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions" + ) + return True + + +def get_pytorch_version() -> List[int]: + """ + This functions finds the PyTorch version. + + Returns: + A tuple of integers in the form of (major, minor, patch). + """ + import torch + + torch_version = torch.__version__.split("+")[0] + TORCH_MAJOR = int(torch_version.split(".")[0]) + TORCH_MINOR = int(torch_version.split(".")[1]) + TORCH_PATCH = int(torch_version.split(".")[2], 16) + return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH + + +def check_pytorch_version(min_major_version, min_minor_version) -> bool: + """ + Compare the current PyTorch version with the minium required version. + + Args: + min_major_version (int): the minimum major version of PyTorch required + min_minor_version (int): the minimum minor version of PyTorch required + + Returns: + A boolean value. The value is True if the current pytorch version is acceptable and False otherwise. + """ + # get pytorch version + torch_major, torch_minor, _ = get_pytorch_version() + + # if the + if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version): + raise RuntimeError( + f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n" + "The latest stable release can be obtained from https://pytorch.org/get-started/locally/" + ) + + +def check_cuda_availability(): + """ + Check if CUDA is available on the system. + + Returns: + A boolean value. True if CUDA is available and False otherwise. + """ + import torch + + return torch.cuda.is_available() + + +def set_cuda_arch_list(cuda_dir): + """ + This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation. + Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'. + """ + cuda_available = check_cuda_availability() + + # we only need to set this when CUDA is not available for cross-compilation + if not cuda_available: + warnings.warn( + "\n[extension] PyTorch did not find available GPUs on this system.\n" + "If your intention is to cross-compile, this is not an error.\n" + "By default, Colossal-AI will cross-compile for \n" + "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "2. Volta (compute capability 7.0)\n" + "3. Turing (compute capability 7.5),\n" + "4. 
Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n" + "\nIf you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n' + ) + + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + + arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"] + + if int(bare_metal_major) == 11: + if int(bare_metal_minor) == 0: + arch_list.append("8.0") + else: + arch_list.append("8.0") + arch_list.append("8.6") + + arch_list_str = ";".join(arch_list) + os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str + return False + return True + + +def get_cuda_cc_flag() -> List[str]: + """ + This function produces the cc flags for your GPU arch + + Returns: + The CUDA cc flags for compilation. + """ + + # only import torch when needed + # this is to avoid importing torch when building on a machine without torch pre-installed + # one case is to build wheel for pypi release + import torch + + cc_flag = [] + max_arch = "".join(str(i) for i in torch.cuda.get_device_capability()) + for arch in torch.cuda.get_arch_list(): + res = re.search(r"sm_(\d+)", arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch): + cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) + return cc_flag + + +def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]: + """ + This function appends the threads flag to your nvcc args. + + Returns: + The nvcc compilation flags including the threads flag. + """ + from torch.utils.cpp_extension import CUDA_HOME + + bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/kernel/flash_attention_loader.py new file mode 100644 index 000000000..3d0cd3975 --- /dev/null +++ b/colossalai/kernel/flash_attention_loader.py @@ -0,0 +1,185 @@ +import math +from collections import OrderedDict +from typing import Optional + +import torch +from einops import rearrange + +from colossalai.accelerator import get_accelerator + +from .base_kernel_loader import BaseKernelLoader +from .extensions.flash_attention import ( + AttnMaskType, + CudaFlashAttnExtension, + CudaMemoryEfficentAttnExtension, + NpuSdpaAttnExtension, + NpuTriangleAttnExtension, + Repad, + SeqLenInfo, + Unpad, +) +from .extensions.utils import print_rank_0 + + +class FlashAttentionLoader(BaseKernelLoader): + """ + FlashAttention Loader + + options: cuda flashh attention, cuda memory effcient attention, npu sdpa attention, npu triangle attention + + Args: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + + def __init__(self): + super().__init__( + # extension name must start with the accelerator name. E.g. 
npu_xxx, cuda_xxx + extension_map=OrderedDict( + cuda_flash_attn=CudaFlashAttnExtension, + cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension, + npu_sdpa_attn=NpuSdpaAttnExtension, + npu_triangle_attn=NpuTriangleAttnExtension, + ), + supported_device=["cuda", "npu"], + ) + + def fetch_kernel(self, backend: str = None): + if backend is not None: + if not self._extension_map[backend]().is_available(): + raise Exception(f"{backend} is not available for flash attention.") + return self._extension_map[backend]().fetch() + + kernel = None + accelerator_name = get_accelerator().name + assert accelerator_name in self._supported_device, f"{accelerator_name} is not supported for flash attention." + for extension_name, extension in self._extension_map.items(): + if extension_name.startswith(accelerator_name): + if extension().is_available(): + kernel = extension().fetch() + break + if kernel is None: + raise Exception("No extension for flash attention is supported") + return kernel + + +class ColoAttention(torch.nn.Module): + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): + super().__init__() + assert ( + embed_dim % num_heads == 0 + ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + if scale is not None: + self.scale = scale + else: + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + self.attn = FlashAttentionLoader().fetch_kernel() + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + origin_attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None, + ): + """ + ColoAttention + + Args: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + origin_attn_mask: (nheads, q_seqlen, kv_seqlen) + bias: will not be used + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + # if flash attention is not applicable, switch to memory effcient attention + if self.attn.__name__ == "flash_attention" and ( + query.dtype not in [torch.float16, torch.bfloat16] or bias != None + ): + print_rank_0("flash attention is not applicable, switch to memory effcient attention") + self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn") + + padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 + causal = attn_mask_type is not None and attn_mask_type.value > 1 + + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + # unpad + seq_len_info_q = None + seq_len_info_kv = None + if padded: + # bert style, unpad process + assert ( + attn_mask is not None + ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, ( + "attention mask is supposed to have shape (batch_size, seq_len), " + + f"but got {attn_mask.dim()} dimensions." 
+ ) + + # bert style + if tgt_len == src_len: + seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query, key, value = self.unpad( + torch.stack([query, key, value], dim=2), seq_len_info_q.indices + ).unbind(dim=1) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + seq_len_info_kv = seq_len_info_q + else: + seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) + seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query = rearrange(query, "b s ... -> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( + dim=1 + ) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + + out = self.attn( + query, + key, + value, + seq_len_info_q=seq_len_info_q, + seq_len_info_kv=seq_len_info_kv, + origin_attn_mask=origin_attn_mask, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded, + ) + + # repad + if padded: + if batch_size > 1: + out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) + out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) + + if len(out.shape) == 4: + out = rearrange(out, "b s h d -> b s (h d)") + return out diff --git a/colossalai/kernel/npu/__init__.py b/colossalai/kernel/npu/__init__.py deleted file mode 100644 index 6a02c7055..000000000 --- a/colossalai/kernel/npu/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import NPUColoAttention - -__all__ = ["NPUColoAttention"] diff --git a/colossalai/kernel/npu/mha/__init__.py b/colossalai/kernel/npu/mha/__init__.py deleted file mode 100644 index 6a02c7055..000000000 --- a/colossalai/kernel/npu/mha/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import NPUColoAttention - -__all__ = ["NPUColoAttention"] diff --git a/colossalai/kernel/npu/mha/mha.py b/colossalai/kernel/npu/mha/mha.py deleted file mode 100644 index ac982384e..000000000 --- a/colossalai/kernel/npu/mha/mha.py +++ /dev/null @@ -1,80 +0,0 @@ -import math -from typing import Optional - -import torch - -from .sdpa_attn import npu_sdpa_attention -from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION - - -class NPUColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None): - super().__init__() - - try: - import torch_npu # noqa - except ImportError: - raise Exception("torch_npu is not installed.") - - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - origin_attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: int = None, - bias: Optional[torch.Tensor] = None, - ): - """ - Implement the scaled dot product attention with softmax. - - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - scale: float. The scaling of QK^T before applying softmax. - Default to 1. - Return: - attn_out: (batch, q_seqlen, nheads, headdim). 
- """ - assert ( - len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4 - ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}" - assert ( - query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu" - ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}" - assert bias is None, "bias is not supported in npu colo attention" - - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - if HAS_NPU_TRIANGLE_ATTENTION: - from .triangle_attn import npu_triangle_attention - - attn_fn = npu_triangle_attention - else: - attn_fn = npu_sdpa_attention - - out = attn_fn( - query, - key, - value, - attn_mask=attn_mask, - origin_attn_mask=origin_attn_mask, - dropout_p=self.dropout, - scale=self.scale, - is_causal=causal, - ) - return out diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 7d53a1dd6..b2f67cae6 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,10 +1,9 @@ import math -import platform from typing import Optional import torch -from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder +from colossalai.kernel import CPUAdamLoader from .nvme_optimizer import NVMeOptimizer @@ -78,7 +77,7 @@ class CPUAdam(NVMeOptimizer): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load() + cpu_adam = CPUAdamLoader().load() # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 55683b227..96fd3bd7b 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -6,7 +6,8 @@ import torch.distributed as dist from torch import nn from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup, get_world_size -from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed + +from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state class SeqParallelUtils: @@ -280,21 +281,3 @@ def create_randomizer_with_offset( Randomizer.increment_index() return Randomizer(seed=base_seed) - - -def get_attention_kernel(): - """ - Get the attention kernel based on the device type. 
- """ - from colossalai.kernel.cuda_native import AttnMaskType - - if torch.cuda.is_available(): - from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel - else: - try: - torch.npu.is_available() - from colossalai.kernel.npu import NPUColoAttention as AttentionKernel - except: - raise Exception("No available device for attention kernel!") - - return AttnMaskType, AttentionKernel diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index 00b2037fb..3522264ad 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -62,7 +62,7 @@ def forward_fn(): def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.kernel import ColoAttention def forward( self: Blip2Attention, diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index c8a311df7..0e469b7dd 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM def get_flash_core_attention_forward(): - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.kernel import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8f4563537..9ab51b90e 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -719,7 +719,7 @@ class GPT2PipelineForwards: def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.kernel import AttnMaskType, ColoAttention def split_heads(tensor, num_heads, attn_head_size): """ diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index c3de197c4..9d02e1376 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,5 +1,5 @@ import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -12,14 +12,15 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer.utils import get_attention_kernel try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask + LATEST_VERSION = True except ImportError: LATEST_VERSION = False + class LlamaPipelineForwards: """ This class serves as a micro library for forward function substitution of Llama models @@ -405,7 +406,7 @@ class LlamaPipelineForwards: def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - AttnMaskType, ColoAttention = get_attention_kernel() + from colossalai.kernel import AttnMaskType, ColoAttention llama_version = 2 try: @@ -469,7 +470,12 @@ def get_llama_flash_attention_forward(): attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) attn_output = attention( - query_states, key_states, value_states, 
attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask, + query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type, + origin_attn_mask=attention_mask, ) attn_output = self.o_proj(attn_output) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 71f2ca335..625b78bd8 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -514,7 +514,7 @@ class OPTPipelineForwards: def get_opt_flash_attention_forward(): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.kernel import AttnMaskType, ColoAttention def forward( self: OPTAttention, diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 5a50e7379..ca3574253 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.kernel import ColoAttention def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 9827d4801..f67f6cd63 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.kernel import AttnMaskType, ColoAttention def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7644317..eee3b505a 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -35,7 +35,7 @@ from transformers.utils import ( replace_return_docstrings, ) -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 6bbe3e4e8..c136f78a1 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -90,9 +90,9 @@ class FusedAdamKernel(AdamKernel): class CPUAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import CPUAdamBuilder + from colossalai.kernel import CPUAdamLoader - cpu_optim = CPUAdamBuilder().load() + cpu_optim = 
CPUAdamLoader().load() self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index a5c465ba0..30a30e86a 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -4,13 +4,11 @@ import pytest import torch from einops import rearrange -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN -from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.kernel.cuda_native import ColoAttention - from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType + from colossalai.kernel import AttnMaskType, ColoAttention DTYPE = [torch.float16, torch.bfloat16, torch.float32]
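Finally, a minimal usage sketch of the relocated public attention API exercised by the tests above; the tensor sizes are illustrative, and it assumes a CUDA device with either flash-attn or xformers installed so that one backend is available:

import torch

from colossalai.kernel import AttnMaskType, ColoAttention, FlashAttentionLoader

batch, seq_len, n_heads, head_dim = 2, 128, 8, 64  # illustrative sizes
q = torch.randn(batch, seq_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# ColoAttention picks a backend through FlashAttentionLoader at construction time
attn = ColoAttention(embed_dim=n_heads * head_dim, num_heads=n_heads)
out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)  # (batch, seq_len, n_heads * head_dim)

# a specific backend can also be requested directly, e.g.
kernel = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn")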