mirror of https://github.com/hpcaitech/ColossalAI
[npu] use extension for op builder (#5172)
* update extension * update cpu adam * update is * add doc for cpu adam * update kernel * update commit * update flash * update memory efficient * update flash attn * update flash attention loader * update api * fix * update doc * update example time limit * reverse change * fix doc * remove useless kernel * fix * not use warning * update * updatepull/5237/head
parent
d6df19bae7
commit
dd2c28a323
|
@ -1,7 +1,14 @@
|
|||
from .cpu_adam_loader import CPUAdamLoader
|
||||
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
|
||||
from .extensions.flash_attention import AttnMaskType
|
||||
from .flash_attention_loader import ColoAttention, FlashAttentionLoader
|
||||
|
||||
__all__ = [
|
||||
"LayerNorm",
|
||||
"FusedScaleMaskSoftmax",
|
||||
"MultiHeadAttention",
|
||||
"CPUAdamLoader",
|
||||
"FlashAttentionLoader",
|
||||
"ColoAttention",
|
||||
"AttnMaskType",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
from .extensions.base_extension import BaseExtension
|
||||
|
||||
|
||||
class BaseKernelLoader(ABC):
|
||||
"""
|
||||
Usage:
|
||||
kernel_loader = KernelLoader()
|
||||
kernel = kernel_loader.load()
|
||||
"""
|
||||
|
||||
def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]):
|
||||
self._extension_map = extension_map
|
||||
self._supported_device = supported_device
|
||||
|
||||
def run_checks(self):
|
||||
# run supported device check and other possible checks
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fetch_kernel(self):
|
||||
pass
|
||||
|
||||
def load(self):
|
||||
self.run_checks()
|
||||
return self.fetch_kernel()
|
|
@ -0,0 +1,64 @@
|
|||
import platform
|
||||
from collections import OrderedDict
|
||||
|
||||
from .base_kernel_loader import BaseKernelLoader
|
||||
from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension
|
||||
|
||||
|
||||
class CPUAdamLoader(BaseKernelLoader):
|
||||
"""
|
||||
CPU Adam Loader
|
||||
|
||||
Usage:
|
||||
# init
|
||||
cpu_adam = CPUAdamLoader().load()
|
||||
cpu_adam_op = cpu_adam.CPUAdamOptimizer(
|
||||
alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
|
||||
)
|
||||
...
|
||||
# optim step
|
||||
cpu_adam_op.step(
|
||||
step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
|
||||
params, grads, exp_avg, exp_avg_sq, loss_scale,
|
||||
)
|
||||
|
||||
Args:
|
||||
func CPUAdamOptimizer:
|
||||
alpha (float): learning rate. Default to 1e-3.
|
||||
beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
|
||||
beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
|
||||
epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
|
||||
weight_decay (float): weight decay (L2 penalty). Default to 0.
|
||||
adamw_mode (bool): whether to use the adamw. Default to True.
|
||||
func step:
|
||||
step (int): current step.
|
||||
lr (float): learning rate.
|
||||
beta1 (float): coefficients used for computing running averages of gradient.
|
||||
beta2 (float): coefficients used for computing running averages of its square.
|
||||
epsilon (float): term added to the denominator to improve numerical stability.
|
||||
weight_decay (float): weight decay (L2 penalty).
|
||||
bias_correction (bool): whether to use bias correction.
|
||||
params (torch.Tensor): parameter.
|
||||
grads (torch.Tensor): gradient.
|
||||
exp_avg (torch.Tensor): exp average.
|
||||
exp_avg_sq (torch.Tensor): exp average square.
|
||||
loss_scale (float): loss scale value.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
extension_map=OrderedDict(
|
||||
arm=ArmCPUAdamExtension,
|
||||
x86=X86CPUAdamExtension,
|
||||
),
|
||||
supported_device=["cpu"],
|
||||
)
|
||||
|
||||
def fetch_kernel(self):
|
||||
if platform.machine() == "x86_64":
|
||||
kernel = self._extension_map["x86"]().fetch()
|
||||
elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
|
||||
kernel = self._extension_map["arm"]().fetch()
|
||||
else:
|
||||
raise Exception("not supported")
|
||||
return kernel
|
|
@ -1,5 +1,4 @@
|
|||
from .layer_norm import MixedFusedLayerNorm as LayerNorm
|
||||
from .mha.mha import ColoAttention
|
||||
from .multihead_attention import MultiHeadAttention
|
||||
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
|
||||
|
||||
|
@ -8,6 +7,5 @@ __all__ = [
|
|||
"MultiHeadAttention",
|
||||
"FusedScaleMaskSoftmax",
|
||||
"ScaledUpperTriangMaskedSoftmax",
|
||||
"ColoAttention",
|
||||
"AttnMaskType",
|
||||
]
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
from .mha import ColoAttention
|
||||
|
||||
__all__ = ["ColoAttention"]
|
|
@ -1,114 +0,0 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from ..scaled_softmax import AttnMaskType
|
||||
from .flash_attn_2 import HAS_FLASH_ATTN
|
||||
from .mem_eff_attn import HAS_MEM_EFF_ATTN
|
||||
from .utils import Repad, SeqLenInfo, Unpad
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
from .flash_attn_2 import flash_attention
|
||||
if HAS_MEM_EFF_ATTN:
|
||||
from .mem_eff_attn import mem_eff_attention
|
||||
|
||||
|
||||
class ColoAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
|
||||
super().__init__()
|
||||
assert (
|
||||
embed_dim % num_heads == 0
|
||||
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
|
||||
if scale is not None:
|
||||
self.scale = scale
|
||||
else:
|
||||
self.scale = 1 / math.sqrt(embed_dim // num_heads)
|
||||
self.dropout = dropout
|
||||
|
||||
if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
|
||||
raise Exception("flash attention can not support!")
|
||||
|
||||
@staticmethod
|
||||
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
return Unpad.apply(tensor, indices)
|
||||
|
||||
@staticmethod
|
||||
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
|
||||
return Repad.apply(tensor, indices, batch_size, seq_len)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
origin_attn_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask_type: Optional[AttnMaskType] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
attn = None
|
||||
if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None:
|
||||
attn = flash_attention
|
||||
else:
|
||||
attn = mem_eff_attention
|
||||
|
||||
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
|
||||
causal = attn_mask_type is not None and attn_mask_type.value > 1
|
||||
|
||||
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
|
||||
# unpad
|
||||
seq_len_info_q = None
|
||||
seq_len_info_kv = None
|
||||
if padded:
|
||||
# bert style, unpad process
|
||||
assert (
|
||||
attn_mask is not None
|
||||
), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
|
||||
assert attn_mask.dim() == 2, (
|
||||
"attention mask is supposed to have shape (batch_size, seq_len), "
|
||||
+ f"but got {attn_mask.dim()} dimensions."
|
||||
)
|
||||
|
||||
# bert style
|
||||
if tgt_len == src_len:
|
||||
seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
|
||||
if batch_size > 1:
|
||||
query, key, value = self.unpad(
|
||||
torch.stack([query, key, value], dim=2), seq_len_info_q.indices
|
||||
).unbind(dim=1)
|
||||
else:
|
||||
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
|
||||
seq_len_info_kv = seq_len_info_q
|
||||
else:
|
||||
seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
|
||||
seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
|
||||
if batch_size > 1:
|
||||
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
|
||||
key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
|
||||
dim=1
|
||||
)
|
||||
else:
|
||||
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
|
||||
|
||||
out = attn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
seq_len_info_q,
|
||||
seq_len_info_kv,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale,
|
||||
causal=causal,
|
||||
padded=padded,
|
||||
)
|
||||
|
||||
# repad
|
||||
if padded:
|
||||
if batch_size > 1:
|
||||
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
|
||||
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
|
||||
|
||||
out = rearrange(out, "b s h d -> b s (h d)")
|
||||
return out
|
|
@ -0,0 +1,21 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class BaseExtension(ABC):
|
||||
@abstractmethod
|
||||
def requires_build(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> Callable:
|
||||
pass
|
||||
|
||||
def fetch(self) -> Callable:
|
||||
if self.requires_build:
|
||||
self.build()
|
||||
return self.load()
|
|
@ -0,0 +1,4 @@
|
|||
from .arm_extension import ArmCPUAdamExtension
|
||||
from .x86_extension import X86CPUAdamExtension
|
||||
|
||||
__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"]
|
|
@ -0,0 +1,53 @@
|
|||
from ..base_extension import BaseExtension
|
||||
from ..extension_builder import ExtensionBuilder
|
||||
|
||||
|
||||
class ArmCPUAdamExtension(BaseExtension):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.kernel_builder = ArmCPUAdamBuilder()
|
||||
self._requires_build = False
|
||||
|
||||
@property
|
||||
def requires_build(self) -> bool:
|
||||
return self._requires_build
|
||||
|
||||
def build(self):
|
||||
self.kernel_builder.build()
|
||||
self._requires_build = True
|
||||
|
||||
def load(self):
|
||||
return self.kernel_builder.load()
|
||||
|
||||
|
||||
class ArmCPUAdamBuilder(ExtensionBuilder):
|
||||
NAME = "arm_cpu_adam"
|
||||
PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
|
||||
ext_type = "cpu"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
|
||||
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
||||
|
||||
# necessary 4 functions
|
||||
def sources_files(self):
|
||||
ret = [
|
||||
self.csrc_abs_path("cpu_adam_arm.cpp"),
|
||||
]
|
||||
return ret
|
||||
|
||||
def include_dirs(self):
|
||||
return [self.csrc_abs_path("includes")]
|
||||
|
||||
def cxx_flags(self):
|
||||
extra_cxx_flags = [
|
||||
"-std=c++14",
|
||||
"-std=c++17",
|
||||
"-g",
|
||||
"-Wno-reorder",
|
||||
"-fopenmp",
|
||||
]
|
||||
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
|
||||
|
||||
def nvcc_flags(self):
|
||||
return []
|
|
@ -0,0 +1,65 @@
|
|||
from ..base_extension import BaseExtension
|
||||
from ..extension_builder import ExtensionBuilder
|
||||
from ..utils import append_nvcc_threads
|
||||
|
||||
|
||||
class X86CPUAdamExtension(BaseExtension):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.kernel_builder = X86CPUAdamBuilder()
|
||||
self._requires_build = False
|
||||
|
||||
@property
|
||||
def requires_build(self) -> bool:
|
||||
return self._requires_build
|
||||
|
||||
def build(self):
|
||||
self.kernel_builder.build()
|
||||
self._requires_build = True
|
||||
|
||||
def load(self):
|
||||
return self.kernel_builder.load()
|
||||
|
||||
|
||||
class X86CPUAdamBuilder(ExtensionBuilder):
|
||||
NAME = "cpu_adam"
|
||||
PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH)
|
||||
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
||||
|
||||
# necessary 4 functions
|
||||
def sources_files(self):
|
||||
ret = [
|
||||
self.csrc_abs_path("cpu_adam.cpp"),
|
||||
]
|
||||
return ret
|
||||
|
||||
def include_dirs(self):
|
||||
return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]
|
||||
|
||||
def cxx_flags(self):
|
||||
extra_cxx_flags = [
|
||||
"-std=c++14",
|
||||
"-std=c++17",
|
||||
"-lcudart",
|
||||
"-lcublas",
|
||||
"-g",
|
||||
"-Wno-reorder",
|
||||
"-fopenmp",
|
||||
"-march=native",
|
||||
]
|
||||
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
|
||||
|
||||
def nvcc_flags(self):
|
||||
extra_cuda_flags = [
|
||||
"-std=c++14",
|
||||
"-std=c++17",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
|
||||
]
|
||||
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
|
||||
return append_nvcc_threads(ret)
|
|
@ -0,0 +1,243 @@
|
|||
# This code has been adapted from the DeepSpeed library.
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
|
||||
# Licensed under the MIT License.
|
||||
import importlib
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
|
||||
|
||||
|
||||
class ExtensionBuilder(ABC):
|
||||
"""
|
||||
Builder is the base class to build extensions for PyTorch.
|
||||
|
||||
Args:
|
||||
name (str): the name of the kernel to be built
|
||||
prebuilt_import_path (str): the path where the extension is installed during pip install
|
||||
"""
|
||||
|
||||
ext_type: str = "cuda"
|
||||
|
||||
def __init__(self, name: str, prebuilt_import_path: str):
|
||||
self.name = name
|
||||
self.prebuilt_import_path = prebuilt_import_path
|
||||
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
||||
|
||||
# we store the op as an attribute to avoid repeated building and loading
|
||||
self.cached_op_module = None
|
||||
|
||||
assert prebuilt_import_path.startswith(
|
||||
"colossalai._C"
|
||||
), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}"
|
||||
|
||||
def relative_to_abs_path(self, code_path: str) -> str:
|
||||
"""
|
||||
This function takes in a path relative to the colossalai root directory and return the absolute path.
|
||||
"""
|
||||
op_builder_module_path = Path(__file__).parent
|
||||
|
||||
# if we install from source
|
||||
# the current file path will be op_builder/builder.py
|
||||
# if we install via pip install colossalai
|
||||
# the current file path will be colossalai/kernel/op_builder/builder.py
|
||||
# this is because that the op_builder inside colossalai is a symlink
|
||||
# this symlink will be replaced with actual files if we install via pypi
|
||||
# thus we cannot tell the colossalai root directory by checking whether the op_builder
|
||||
# is a symlink, we can only tell whether it is inside or outside colossalai
|
||||
if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"):
|
||||
root_path = op_builder_module_path.parent.parent
|
||||
elif str(op_builder_module_path).endswith("colossalai/kernel/extensions"):
|
||||
root_path = op_builder_module_path.parent.parent
|
||||
else:
|
||||
root_path = op_builder_module_path.parent.joinpath("colossalai")
|
||||
|
||||
code_abs_path = root_path.joinpath(code_path)
|
||||
return str(code_abs_path)
|
||||
|
||||
def get_cuda_home_include(self):
|
||||
"""
|
||||
return include path inside the cuda home.
|
||||
"""
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
|
||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||
return cuda_include
|
||||
|
||||
def csrc_abs_path(self, path):
|
||||
return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path)
|
||||
|
||||
# functions must be overrided begin
|
||||
@abstractmethod
|
||||
def sources_files(self) -> List[str]:
|
||||
"""
|
||||
This function should return a list of source files for extensions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def include_dirs(self) -> List[str]:
|
||||
"""
|
||||
This function should return a list of include files for extensions.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def cxx_flags(self) -> List[str]:
|
||||
"""
|
||||
This function should return a list of cxx compilation flags for extensions.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def nvcc_flags(self) -> List[str]:
|
||||
"""
|
||||
This function should return a list of nvcc compilation flags for extensions.
|
||||
"""
|
||||
|
||||
# functions must be overrided over
|
||||
def strip_empty_entries(self, args):
|
||||
"""
|
||||
Drop any empty strings from the list of compile and link flags
|
||||
"""
|
||||
return [x for x in args if len(x) > 0]
|
||||
|
||||
def import_op(self):
|
||||
"""
|
||||
This function will import the op module by its string name.
|
||||
"""
|
||||
return importlib.import_module(self.prebuilt_import_path)
|
||||
|
||||
def check_runtime_build_environment(self):
|
||||
"""
|
||||
Check whether the system environment is ready for extension compilation.
|
||||
"""
|
||||
try:
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
CUDA_HOME = None
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
raise ModuleNotFoundError(
|
||||
"PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions"
|
||||
)
|
||||
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
"CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions"
|
||||
)
|
||||
|
||||
# make sure CUDA is available for compilation during
|
||||
cuda_available = check_cuda_availability()
|
||||
if not cuda_available:
|
||||
raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.")
|
||||
|
||||
# make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not
|
||||
check_system_pytorch_cuda_match(CUDA_HOME)
|
||||
|
||||
def build(self, verbose: Optional[bool] = None):
|
||||
"""
|
||||
If the kernel is not built during pip install, it will build the kernel.
|
||||
If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
|
||||
kernel is built during pip install, it can be accessed through `colossalai._C`.
|
||||
|
||||
Warning: do not load this kernel repeatedly during model execution as it could slow down the training process.
|
||||
|
||||
Args:
|
||||
verbose (bool, optional): show detailed info. Defaults to True.
|
||||
"""
|
||||
if verbose is None:
|
||||
verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1"
|
||||
try:
|
||||
# if the kernel has been pre-built during installation
|
||||
# we just directly import it
|
||||
op_module = self.import_op()
|
||||
if verbose:
|
||||
print_rank_0(
|
||||
f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building."
|
||||
)
|
||||
except ImportError:
|
||||
# check environment
|
||||
if self.ext_type == "cuda":
|
||||
self.check_runtime_build_environment()
|
||||
|
||||
# time the kernel compilation
|
||||
start_build = time.time()
|
||||
|
||||
# construct the build directory
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
torch_version_major = torch.__version__.split(".")[0]
|
||||
torch_version_minor = torch.__version__.split(".")[1]
|
||||
torch_cuda_version = torch.version.cuda
|
||||
home_directory = os.path.expanduser("~")
|
||||
extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
|
||||
build_directory = os.path.join(home_directory, extension_directory)
|
||||
Path(build_directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if verbose:
|
||||
print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now")
|
||||
|
||||
# load the kernel
|
||||
op_module = load(
|
||||
name=self.name,
|
||||
sources=self.strip_empty_entries(self.sources_files()),
|
||||
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
|
||||
extra_cflags=self.cxx_flags(),
|
||||
extra_cuda_cflags=self.nvcc_flags(),
|
||||
extra_ldflags=[],
|
||||
build_directory=build_directory,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
build_duration = time.time() - start_build
|
||||
|
||||
# log jit compilation time
|
||||
if verbose:
|
||||
print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds")
|
||||
|
||||
# cache the built/loaded kernel
|
||||
self.cached_op_module = op_module
|
||||
|
||||
def load(self, verbose: Optional[bool] = None):
|
||||
"""
|
||||
load the kernel during runtime.
|
||||
|
||||
Args:
|
||||
verbose (bool, optional): show detailed info. Defaults to True.
|
||||
"""
|
||||
# if the kernel has be compiled and cached, we directly use it
|
||||
assert self.cached_op_module is not None, "Please build the kernel first before loading it."
|
||||
return self.cached_op_module
|
||||
|
||||
def builder(self) -> Union["CUDAExtension", "CppExtension"]:
|
||||
"""
|
||||
get a CUDAExtension instance used for setup.py
|
||||
"""
|
||||
from torch.utils.cpp_extension import CppExtension, CUDAExtension
|
||||
|
||||
if self.ext_type == "cpp":
|
||||
return CppExtension(
|
||||
name=self.prebuilt_import_path,
|
||||
sources=self.strip_empty_entries(self.sources_files()),
|
||||
include_dirs=self.strip_empty_entries(self.include_dirs()),
|
||||
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
|
||||
)
|
||||
|
||||
return CUDAExtension(
|
||||
name=self.prebuilt_import_path,
|
||||
sources=self.strip_empty_entries(self.sources_files()),
|
||||
include_dirs=self.strip_empty_entries(self.include_dirs()),
|
||||
extra_compile_args={
|
||||
"cxx": self.strip_empty_entries(self.cxx_flags()),
|
||||
"nvcc": self.strip_empty_entries(self.nvcc_flags()),
|
||||
},
|
||||
)
|
|
@ -0,0 +1,19 @@
|
|||
from .cuda_flash_attn_2_extension import HAS_FLASH_ATTN, CudaFlashAttnExtension
|
||||
from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension
|
||||
from .npu_sdpa_attn_extension import NpuSdpaAttnExtension
|
||||
from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension
|
||||
from .utils import AttnMaskType, Repad, SeqLenInfo, Unpad
|
||||
|
||||
__all__ = [
|
||||
"CudaFlashAttnExtension",
|
||||
"CudaMemoryEfficentAttnExtension",
|
||||
"NpuSdpaAttnExtension",
|
||||
"NpuTriangleAttnExtension",
|
||||
"HAS_FLASH_ATTN",
|
||||
"HAS_MEM_EFF_ATTN",
|
||||
"HAS_NPU_TRIANGLE_ATTENTION",
|
||||
"Unpad",
|
||||
"AttnMaskType",
|
||||
"Repad",
|
||||
"SeqLenInfo",
|
||||
]
|
|
@ -1,10 +1,14 @@
|
|||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..base_extension import BaseExtension
|
||||
from ..utils import print_rank_0
|
||||
from .utils import SeqLenInfo
|
||||
|
||||
|
||||
def is_ampere_or_better_gpu():
|
||||
# Check Ampere GPUs or newer
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
properties = torch.cuda.get_device_properties(device)
|
||||
|
@ -13,31 +17,28 @@ def is_ampere_or_better_gpu():
|
|||
return False
|
||||
|
||||
|
||||
# "Check Ampere GPUs or newer"
|
||||
HAS_FLASH_ATTN = False
|
||||
ERROR_MSG = None
|
||||
if is_ampere_or_better_gpu():
|
||||
HAS_FLASH_ATTN = True
|
||||
else:
|
||||
warnings.warn("FlashAttention only supports Ampere GPUs or newer.")
|
||||
HAS_FLASH_ATTN = False
|
||||
try:
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
|
||||
|
||||
HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention")
|
||||
HAS_FLASH_ATTN = False
|
||||
except ImportError:
|
||||
ERROR_MSG = "ImportError: please install flash_attn from https://github.com/HazyResearch/flash-attention"
|
||||
else:
|
||||
ERROR_MSG = "ImportError: FlashAttention only supports Ampere GPUs or newer."
|
||||
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
|
||||
from .utils import SeqLenInfo
|
||||
|
||||
def flash_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seq_len_info_q: SeqLenInfo,
|
||||
seq_len_info_kv: SeqLenInfo,
|
||||
origin_attn_mask: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: float = None,
|
||||
|
@ -77,3 +78,23 @@ if HAS_FLASH_ATTN:
|
|||
else:
|
||||
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
|
||||
return attn_out
|
||||
|
||||
|
||||
class CudaFlashAttnExtension(BaseExtension):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def requires_build(self):
|
||||
return False
|
||||
|
||||
def build(self):
|
||||
pass
|
||||
|
||||
def is_available(self):
|
||||
if HAS_FLASH_ATTN == False:
|
||||
print_rank_0(ERROR_MSG)
|
||||
return HAS_FLASH_ATTN
|
||||
|
||||
def load(self):
|
||||
return flash_attention
|
|
@ -1,4 +1,10 @@
|
|||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..base_extension import BaseExtension
|
||||
from ..utils import print_rank_0
|
||||
from .utils import SeqLenInfo
|
||||
|
||||
HAS_MEM_EFF_ATTN = False
|
||||
try:
|
||||
|
@ -12,19 +18,13 @@ try:
|
|||
|
||||
HAS_MEM_EFF_ATTN = True
|
||||
except ImportError:
|
||||
warnings.warn("please install xformers from https://github.com/facebookresearch/xformers")
|
||||
HAS_MEM_EFF_ATTN = False
|
||||
pass
|
||||
|
||||
if HAS_MEM_EFF_ATTN:
|
||||
"""
|
||||
A general attention module using the flash attention kernels from xformers:
|
||||
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import SeqLenInfo
|
||||
|
||||
allow_alibi = True
|
||||
for op in MemoryEfficientAttentionCutlassOp:
|
||||
|
@ -36,6 +36,7 @@ if HAS_MEM_EFF_ATTN:
|
|||
v: torch.Tensor,
|
||||
seq_len_info_q: SeqLenInfo,
|
||||
seq_len_info_kv: SeqLenInfo,
|
||||
origin_attn_mask: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: float = None,
|
||||
|
@ -68,3 +69,23 @@ if HAS_MEM_EFF_ATTN:
|
|||
out = out.squeeze(0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CudaMemoryEfficentAttnExtension(BaseExtension):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def requires_build(self) -> bool:
|
||||
return False
|
||||
|
||||
def build(self):
|
||||
pass
|
||||
|
||||
def is_available(self):
|
||||
if HAS_MEM_EFF_ATTN == False:
|
||||
print_rank_0("ImportError: please install xformers from https://github.com/facebookresearch/xformers")
|
||||
return HAS_MEM_EFF_ATTN
|
||||
|
||||
def load(self):
|
||||
return mem_eff_attention
|
|
@ -1,16 +1,20 @@
|
|||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from ..base_extension import BaseExtension
|
||||
|
||||
|
||||
def npu_sdpa_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
attn_mask: torch.Tensor = None,
|
||||
seq_len_info_q=None,
|
||||
seq_len_info_kv=None,
|
||||
origin_attn_mask: torch.Tensor = None,
|
||||
scale: float = 1.0,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = True,
|
||||
scale: float = 1.0,
|
||||
causal=None,
|
||||
padded=None,
|
||||
):
|
||||
"""
|
||||
The scaled dot product attention.
|
||||
|
@ -39,3 +43,18 @@ def npu_sdpa_attention(
|
|||
)
|
||||
output = rearrange(output, "b h s d -> b s (h d)")
|
||||
return output
|
||||
|
||||
|
||||
class NpuSdpaAttnExtension(BaseExtension):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def requires_build(self) -> bool:
|
||||
return False
|
||||
|
||||
def build(self):
|
||||
pass
|
||||
|
||||
def load(self):
|
||||
return npu_sdpa_attention
|
|
@ -13,18 +13,20 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from ..base_extension import BaseExtension
|
||||
from ..utils import print_rank_0
|
||||
|
||||
HAS_NPU_TRIANGLE_ATTENTION = False
|
||||
try:
|
||||
from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
|
||||
|
||||
HAS_NPU_TRIANGLE_ATTENTION = True
|
||||
except ImportError:
|
||||
logging.warning("Import torch_npu Error.")
|
||||
pass
|
||||
|
||||
|
||||
if HAS_NPU_TRIANGLE_ATTENTION:
|
||||
|
@ -33,11 +35,13 @@ if HAS_NPU_TRIANGLE_ATTENTION:
|
|||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
attn_mask: torch.Tensor = None,
|
||||
seq_len_info_q=None,
|
||||
seq_len_info_kv=None,
|
||||
origin_attn_mask: torch.Tensor = None,
|
||||
scale: float = 1.0,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = True,
|
||||
scale: float = 1.0,
|
||||
causal=None,
|
||||
padded=None,
|
||||
block_size=512,
|
||||
):
|
||||
"""
|
||||
|
@ -113,3 +117,25 @@ if HAS_NPU_TRIANGLE_ATTENTION:
|
|||
# Context layer. [b, sq, hp]
|
||||
# =========================
|
||||
return context_layer
|
||||
|
||||
|
||||
class NpuTriangleAttnExtension(BaseExtension):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def requires_build(self) -> bool:
|
||||
return False
|
||||
|
||||
def build(self):
|
||||
pass
|
||||
|
||||
def is_available(self):
|
||||
if HAS_NPU_TRIANGLE_ATTENTION == False:
|
||||
print_rank_0(
|
||||
"ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api."
|
||||
)
|
||||
return HAS_NPU_TRIANGLE_ATTENTION
|
||||
|
||||
def load(self):
|
||||
return npu_triangle_attention
|
|
@ -1,3 +1,4 @@
|
|||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
|
@ -80,3 +81,9 @@ class SeqLenInfo:
|
|||
max_seqlen = max(seqlens)
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
|
||||
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
|
||||
|
||||
|
||||
class AttnMaskType(enum.Enum):
|
||||
padding = 1
|
||||
causal = 2
|
||||
paddedcausal = 3
|
|
@ -0,0 +1,229 @@
|
|||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
|
||||
def print_rank_0(message: str) -> None:
|
||||
"""
|
||||
Print on only one process to avoid spamming.
|
||||
"""
|
||||
try:
|
||||
import torch.distributed as dist
|
||||
|
||||
if not dist.is_initialized():
|
||||
is_main_rank = True
|
||||
else:
|
||||
is_main_rank = dist.get_rank() == 0
|
||||
except ImportError:
|
||||
is_main_rank = True
|
||||
|
||||
if is_main_rank:
|
||||
print(message)
|
||||
|
||||
|
||||
def get_cuda_version_in_pytorch() -> List[int]:
|
||||
"""
|
||||
This function returns the CUDA version in the PyTorch build.
|
||||
|
||||
Returns:
|
||||
The CUDA version required by PyTorch, in the form of tuple (major, minor).
|
||||
"""
|
||||
import torch
|
||||
|
||||
try:
|
||||
torch_cuda_major = torch.version.cuda.split(".")[0]
|
||||
torch_cuda_minor = torch.version.cuda.split(".")[1]
|
||||
except:
|
||||
raise ValueError(
|
||||
"[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda"
|
||||
)
|
||||
return torch_cuda_major, torch_cuda_minor
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
|
||||
"""
|
||||
Get the System CUDA version from nvcc.
|
||||
|
||||
Args:
|
||||
cuda_dir (str): the directory for CUDA Toolkit.
|
||||
|
||||
Returns:
|
||||
The CUDA version required by PyTorch, in the form of tuple (major, minor).
|
||||
"""
|
||||
nvcc_path = os.path.join(cuda_dir, "bin/nvcc")
|
||||
|
||||
if cuda_dir is None:
|
||||
raise ValueError(
|
||||
f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly."
|
||||
)
|
||||
|
||||
# check for nvcc path
|
||||
if not os.path.exists(nvcc_path):
|
||||
raise FileNotFoundError(
|
||||
f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
|
||||
)
|
||||
|
||||
# parse the nvcc -v output to obtain the system cuda version
|
||||
try:
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
bare_metal_major = release[0]
|
||||
bare_metal_minor = release[1][0]
|
||||
except:
|
||||
raise ValueError(
|
||||
f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}"
|
||||
)
|
||||
|
||||
return bare_metal_major, bare_metal_minor
|
||||
|
||||
|
||||
def check_system_pytorch_cuda_match(cuda_dir):
|
||||
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
|
||||
torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()
|
||||
|
||||
if bare_metal_major != torch_cuda_major:
|
||||
raise Exception(
|
||||
f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
|
||||
f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})."
|
||||
"Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
|
||||
)
|
||||
|
||||
if bare_metal_minor != torch_cuda_minor:
|
||||
warnings.warn(
|
||||
f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
|
||||
"The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
|
||||
"If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def get_pytorch_version() -> List[int]:
|
||||
"""
|
||||
This functions finds the PyTorch version.
|
||||
|
||||
Returns:
|
||||
A tuple of integers in the form of (major, minor, patch).
|
||||
"""
|
||||
import torch
|
||||
|
||||
torch_version = torch.__version__.split("+")[0]
|
||||
TORCH_MAJOR = int(torch_version.split(".")[0])
|
||||
TORCH_MINOR = int(torch_version.split(".")[1])
|
||||
TORCH_PATCH = int(torch_version.split(".")[2], 16)
|
||||
return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH
|
||||
|
||||
|
||||
def check_pytorch_version(min_major_version, min_minor_version) -> bool:
|
||||
"""
|
||||
Compare the current PyTorch version with the minium required version.
|
||||
|
||||
Args:
|
||||
min_major_version (int): the minimum major version of PyTorch required
|
||||
min_minor_version (int): the minimum minor version of PyTorch required
|
||||
|
||||
Returns:
|
||||
A boolean value. The value is True if the current pytorch version is acceptable and False otherwise.
|
||||
"""
|
||||
# get pytorch version
|
||||
torch_major, torch_minor, _ = get_pytorch_version()
|
||||
|
||||
# if the
|
||||
if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
|
||||
raise RuntimeError(
|
||||
f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
|
||||
"The latest stable release can be obtained from https://pytorch.org/get-started/locally/"
|
||||
)
|
||||
|
||||
|
||||
def check_cuda_availability():
|
||||
"""
|
||||
Check if CUDA is available on the system.
|
||||
|
||||
Returns:
|
||||
A boolean value. True if CUDA is available and False otherwise.
|
||||
"""
|
||||
import torch
|
||||
|
||||
return torch.cuda.is_available()
|
||||
|
||||
|
||||
def set_cuda_arch_list(cuda_dir):
|
||||
"""
|
||||
This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
|
||||
Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'.
|
||||
"""
|
||||
cuda_available = check_cuda_availability()
|
||||
|
||||
# we only need to set this when CUDA is not available for cross-compilation
|
||||
if not cuda_available:
|
||||
warnings.warn(
|
||||
"\n[extension] PyTorch did not find available GPUs on this system.\n"
|
||||
"If your intention is to cross-compile, this is not an error.\n"
|
||||
"By default, Colossal-AI will cross-compile for \n"
|
||||
"1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
|
||||
"2. Volta (compute capability 7.0)\n"
|
||||
"3. Turing (compute capability 7.5),\n"
|
||||
"4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n"
|
||||
"\nIf you wish to cross-compile for a single specific architecture,\n"
|
||||
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
|
||||
)
|
||||
|
||||
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
|
||||
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
|
||||
|
||||
arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]
|
||||
|
||||
if int(bare_metal_major) == 11:
|
||||
if int(bare_metal_minor) == 0:
|
||||
arch_list.append("8.0")
|
||||
else:
|
||||
arch_list.append("8.0")
|
||||
arch_list.append("8.6")
|
||||
|
||||
arch_list_str = ";".join(arch_list)
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_cuda_cc_flag() -> List[str]:
|
||||
"""
|
||||
This function produces the cc flags for your GPU arch
|
||||
|
||||
Returns:
|
||||
The CUDA cc flags for compilation.
|
||||
"""
|
||||
|
||||
# only import torch when needed
|
||||
# this is to avoid importing torch when building on a machine without torch pre-installed
|
||||
# one case is to build wheel for pypi release
|
||||
import torch
|
||||
|
||||
cc_flag = []
|
||||
max_arch = "".join(str(i) for i in torch.cuda.get_device_capability())
|
||||
for arch in torch.cuda.get_arch_list():
|
||||
res = re.search(r"sm_(\d+)", arch)
|
||||
if res:
|
||||
arch_cap = res[1]
|
||||
if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch):
|
||||
cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"])
|
||||
return cc_flag
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
|
||||
"""
|
||||
This function appends the threads flag to your nvcc args.
|
||||
|
||||
Returns:
|
||||
The nvcc compilation flags including the threads flag.
|
||||
"""
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
|
||||
return nvcc_extra_args + ["--threads", "4"]
|
||||
return nvcc_extra_args
|
|
@ -0,0 +1,185 @@
|
|||
import math
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .base_kernel_loader import BaseKernelLoader
|
||||
from .extensions.flash_attention import (
|
||||
AttnMaskType,
|
||||
CudaFlashAttnExtension,
|
||||
CudaMemoryEfficentAttnExtension,
|
||||
NpuSdpaAttnExtension,
|
||||
NpuTriangleAttnExtension,
|
||||
Repad,
|
||||
SeqLenInfo,
|
||||
Unpad,
|
||||
)
|
||||
from .extensions.utils import print_rank_0
|
||||
|
||||
|
||||
class FlashAttentionLoader(BaseKernelLoader):
|
||||
"""
|
||||
FlashAttention Loader
|
||||
|
||||
options: cuda flashh attention, cuda memory effcient attention, npu sdpa attention, npu triangle attention
|
||||
|
||||
Args:
|
||||
q: (batch, q_seqlen, nheads, headdim)
|
||||
k: (batch, kv_seqlen, nheads, headdim)
|
||||
v: (batch, kv_seqlen, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
dropout_p: float. Dropout probability.
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
Return:
|
||||
attn_out: (batch, q_seqlen, nheads, headdim).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
# extension name must start with the accelerator name. E.g. npu_xxx, cuda_xxx
|
||||
extension_map=OrderedDict(
|
||||
cuda_flash_attn=CudaFlashAttnExtension,
|
||||
cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension,
|
||||
npu_sdpa_attn=NpuSdpaAttnExtension,
|
||||
npu_triangle_attn=NpuTriangleAttnExtension,
|
||||
),
|
||||
supported_device=["cuda", "npu"],
|
||||
)
|
||||
|
||||
def fetch_kernel(self, backend: str = None):
|
||||
if backend is not None:
|
||||
if not self._extension_map[backend]().is_available():
|
||||
raise Exception(f"{backend} is not available for flash attention.")
|
||||
return self._extension_map[backend]().fetch()
|
||||
|
||||
kernel = None
|
||||
accelerator_name = get_accelerator().name
|
||||
assert accelerator_name in self._supported_device, f"{accelerator_name} is not supported for flash attention."
|
||||
for extension_name, extension in self._extension_map.items():
|
||||
if extension_name.startswith(accelerator_name):
|
||||
if extension().is_available():
|
||||
kernel = extension().fetch()
|
||||
break
|
||||
if kernel is None:
|
||||
raise Exception("No extension for flash attention is supported")
|
||||
return kernel
|
||||
|
||||
|
||||
class ColoAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
|
||||
super().__init__()
|
||||
assert (
|
||||
embed_dim % num_heads == 0
|
||||
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
|
||||
if scale is not None:
|
||||
self.scale = scale
|
||||
else:
|
||||
self.scale = 1 / math.sqrt(embed_dim // num_heads)
|
||||
self.dropout = dropout
|
||||
|
||||
self.attn = FlashAttentionLoader().fetch_kernel()
|
||||
|
||||
@staticmethod
|
||||
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
return Unpad.apply(tensor, indices)
|
||||
|
||||
@staticmethod
|
||||
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
|
||||
return Repad.apply(tensor, indices, batch_size, seq_len)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
origin_attn_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask_type: Optional[AttnMaskType] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
ColoAttention
|
||||
|
||||
Args:
|
||||
q: (batch, q_seqlen, nheads, headdim)
|
||||
k: (batch, kv_seqlen, nheads, headdim)
|
||||
v: (batch, kv_seqlen, nheads, headdim)
|
||||
origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
|
||||
bias: will not be used
|
||||
Return:
|
||||
attn_out: (batch, q_seqlen, nheads, headdim).
|
||||
"""
|
||||
# if flash attention is not applicable, switch to memory effcient attention
|
||||
if self.attn.__name__ == "flash_attention" and (
|
||||
query.dtype not in [torch.float16, torch.bfloat16] or bias != None
|
||||
):
|
||||
print_rank_0("flash attention is not applicable, switch to memory effcient attention")
|
||||
self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn")
|
||||
|
||||
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
|
||||
causal = attn_mask_type is not None and attn_mask_type.value > 1
|
||||
|
||||
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
|
||||
# unpad
|
||||
seq_len_info_q = None
|
||||
seq_len_info_kv = None
|
||||
if padded:
|
||||
# bert style, unpad process
|
||||
assert (
|
||||
attn_mask is not None
|
||||
), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
|
||||
assert attn_mask.dim() == 2, (
|
||||
"attention mask is supposed to have shape (batch_size, seq_len), "
|
||||
+ f"but got {attn_mask.dim()} dimensions."
|
||||
)
|
||||
|
||||
# bert style
|
||||
if tgt_len == src_len:
|
||||
seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
|
||||
if batch_size > 1:
|
||||
query, key, value = self.unpad(
|
||||
torch.stack([query, key, value], dim=2), seq_len_info_q.indices
|
||||
).unbind(dim=1)
|
||||
else:
|
||||
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
|
||||
seq_len_info_kv = seq_len_info_q
|
||||
else:
|
||||
seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
|
||||
seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
|
||||
if batch_size > 1:
|
||||
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
|
||||
key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
|
||||
dim=1
|
||||
)
|
||||
else:
|
||||
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
|
||||
|
||||
out = self.attn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
seq_len_info_q=seq_len_info_q,
|
||||
seq_len_info_kv=seq_len_info_kv,
|
||||
origin_attn_mask=origin_attn_mask,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale,
|
||||
causal=causal,
|
||||
padded=padded,
|
||||
)
|
||||
|
||||
# repad
|
||||
if padded:
|
||||
if batch_size > 1:
|
||||
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
|
||||
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
|
||||
|
||||
if len(out.shape) == 4:
|
||||
out = rearrange(out, "b s h d -> b s (h d)")
|
||||
return out
|
|
@ -1,3 +0,0 @@
|
|||
from .mha import NPUColoAttention
|
||||
|
||||
__all__ = ["NPUColoAttention"]
|
|
@ -1,3 +0,0 @@
|
|||
from .mha import NPUColoAttention
|
||||
|
||||
__all__ = ["NPUColoAttention"]
|
|
@ -1,80 +0,0 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .sdpa_attn import npu_sdpa_attention
|
||||
from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION
|
||||
|
||||
|
||||
class NPUColoAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
|
||||
super().__init__()
|
||||
|
||||
try:
|
||||
import torch_npu # noqa
|
||||
except ImportError:
|
||||
raise Exception("torch_npu is not installed.")
|
||||
|
||||
assert (
|
||||
embed_dim % num_heads == 0
|
||||
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
|
||||
if scale is not None:
|
||||
self.scale = scale
|
||||
else:
|
||||
self.scale = 1 / math.sqrt(embed_dim // num_heads)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
origin_attn_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask_type: int = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Implement the scaled dot product attention with softmax.
|
||||
|
||||
Arguments:
|
||||
q: (batch, q_seqlen, nheads, headdim)
|
||||
k: (batch, kv_seqlen, nheads, headdim)
|
||||
v: (batch, kv_seqlen, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
dropout_p: float. Dropout probability.
|
||||
scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1.
|
||||
Return:
|
||||
attn_out: (batch, q_seqlen, nheads, headdim).
|
||||
"""
|
||||
assert (
|
||||
len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
|
||||
), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
|
||||
assert (
|
||||
query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
|
||||
), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
|
||||
assert bias is None, "bias is not supported in npu colo attention"
|
||||
|
||||
causal = attn_mask_type is not None and attn_mask_type.value > 1
|
||||
|
||||
if HAS_NPU_TRIANGLE_ATTENTION:
|
||||
from .triangle_attn import npu_triangle_attention
|
||||
|
||||
attn_fn = npu_triangle_attention
|
||||
else:
|
||||
attn_fn = npu_sdpa_attention
|
||||
|
||||
out = attn_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attn_mask,
|
||||
origin_attn_mask=origin_attn_mask,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale,
|
||||
is_causal=causal,
|
||||
)
|
||||
return out
|
|
@ -1,10 +1,9 @@
|
|||
import math
|
||||
import platform
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
|
||||
from colossalai.kernel import CPUAdamLoader
|
||||
|
||||
from .nvme_optimizer import NVMeOptimizer
|
||||
|
||||
|
@ -78,7 +77,7 @@ class CPUAdam(NVMeOptimizer):
|
|||
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
||||
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||
self.adamw_mode = adamw_mode
|
||||
cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
|
||||
cpu_adam = CPUAdamLoader().load()
|
||||
# if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
|
||||
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
||||
|
||||
|
|
|
@ -6,7 +6,8 @@ import torch.distributed as dist
|
|||
from torch import nn
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ProcessGroup, get_world_size
|
||||
from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
|
||||
|
||||
from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state
|
||||
|
||||
|
||||
class SeqParallelUtils:
|
||||
|
@ -280,21 +281,3 @@ def create_randomizer_with_offset(
|
|||
Randomizer.increment_index()
|
||||
|
||||
return Randomizer(seed=base_seed)
|
||||
|
||||
|
||||
def get_attention_kernel():
|
||||
"""
|
||||
Get the attention kernel based on the device type.
|
||||
"""
|
||||
from colossalai.kernel.cuda_native import AttnMaskType
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
|
||||
else:
|
||||
try:
|
||||
torch.npu.is_available()
|
||||
from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
|
||||
except:
|
||||
raise Exception("No available device for attention kernel!")
|
||||
|
||||
return AttnMaskType, AttentionKernel
|
||||
|
|
|
@ -62,7 +62,7 @@ def forward_fn():
|
|||
def get_blip2_flash_attention_forward():
|
||||
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
|
||||
|
||||
from colossalai.kernel.cuda_native import ColoAttention
|
||||
from colossalai.kernel import ColoAttention
|
||||
|
||||
def forward(
|
||||
self: Blip2Attention,
|
||||
|
|
|
@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
|
|||
|
||||
|
||||
def get_flash_core_attention_forward():
|
||||
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
|
||||
from colossalai.kernel import AttnMaskType, ColoAttention
|
||||
|
||||
from .chatglm2_6b.modeling_chatglm import CoreAttention
|
||||
|
||||
|
|
|
@ -719,7 +719,7 @@ class GPT2PipelineForwards:
|
|||
def get_gpt2_flash_attention_forward():
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
||||
|
||||
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
|
||||
from colossalai.kernel import AttnMaskType, ColoAttention
|
||||
|
||||
def split_heads(tensor, num_heads, attn_head_size):
|
||||
"""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
@ -12,14 +12,15 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
|
|||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer.utils import get_attention_kernel
|
||||
|
||||
try:
|
||||
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
|
||||
|
||||
LATEST_VERSION = True
|
||||
except ImportError:
|
||||
LATEST_VERSION = False
|
||||
|
||||
|
||||
class LlamaPipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Llama models
|
||||
|
@ -405,7 +406,7 @@ class LlamaPipelineForwards:
|
|||
def get_llama_flash_attention_forward():
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||
|
||||
AttnMaskType, ColoAttention = get_attention_kernel()
|
||||
from colossalai.kernel import AttnMaskType, ColoAttention
|
||||
|
||||
llama_version = 2
|
||||
try:
|
||||
|
@ -469,7 +470,12 @@ def get_llama_flash_attention_forward():
|
|||
|
||||
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
|
||||
attn_output = attention(
|
||||
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=flash_attention_mask,
|
||||
attn_mask_type=attn_mask_type,
|
||||
origin_attn_mask=attention_mask,
|
||||
)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
|
|
@ -514,7 +514,7 @@ class OPTPipelineForwards:
|
|||
def get_opt_flash_attention_forward():
|
||||
from transformers.models.opt.modeling_opt import OPTAttention
|
||||
|
||||
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
|
||||
from colossalai.kernel import AttnMaskType, ColoAttention
|
||||
|
||||
def forward(
|
||||
self: OPTAttention,
|
||||
|
|
|
@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
|
|||
def get_vit_flash_self_attention_forward():
|
||||
from transformers.models.vit.modeling_vit import ViTSelfAttention
|
||||
|
||||
from colossalai.kernel.cuda_native import ColoAttention
|
||||
from colossalai.kernel import ColoAttention
|
||||
|
||||
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
|
||||
|
|
|
@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
def get_whisper_flash_attention_forward():
|
||||
from transformers.models.whisper.modeling_whisper import WhisperAttention
|
||||
|
||||
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
|
||||
from colossalai.kernel import AttnMaskType, ColoAttention
|
||||
|
||||
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
|
||||
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
|
||||
|
|
|
@ -35,7 +35,7 @@ from transformers.utils import (
|
|||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
|
||||
from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN
|
||||
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
|
||||
from colossalai.moe.layers import SparseMLP
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
|
|
|
@ -90,9 +90,9 @@ class FusedAdamKernel(AdamKernel):
|
|||
class CPUAdamKernel(AdamKernel):
|
||||
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
|
||||
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
|
||||
from colossalai.kernel.op_builder import CPUAdamBuilder
|
||||
from colossalai.kernel import CPUAdamLoader
|
||||
|
||||
cpu_optim = CPUAdamBuilder().load()
|
||||
cpu_optim = CPUAdamLoader().load()
|
||||
|
||||
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)
|
||||
|
||||
|
|
|
@ -4,13 +4,11 @@ import pytest
|
|||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
|
||||
from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
|
||||
from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
|
||||
from colossalai.kernel.cuda_native import ColoAttention
|
||||
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
|
||||
from colossalai.kernel import AttnMaskType, ColoAttention
|
||||
|
||||
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
|
|
Loading…
Reference in New Issue