pull/5598/head
CjhHa1 2024-04-23 02:28:04 +00:00
parent f57b12df6e
commit b9305fb024
69 changed files with 9900 additions and 2 deletions

View File

@ -1 +0,0 @@
../../extensions

View File

@ -0,0 +1,140 @@
# 🔌 Extensions
## 📌 Table of Contents
- [🔌 Extensions](#-extensions)
- [📌 Table of Contents](#-table-of-contents)
- [📚 Introduction](#-introduction)
- [🪅 Design](#-design)
- [🛠 API Usage](#-api-usage)
- [🏗 Write a customized extension](#-write-a-customized-extension)
- [✏️ Acknowledgement](#-acknowledgement)
## 📚 Introduction
This module is designed to offer extensions to the existing ColossalAI framework. It is a collection of high-performance kernels that speed up training and inference. Unlike writing an individual kernel, the `extensions` module offers a layer of abstraction that collates kernels written for different compiler backends and different hardware backends in an organized way. Please see the design and usage in the sections below.
## 🪅 Design
The `extensions` module is a sub-module of the `colossalai.kernel` module. This module is put at the project root directory so that it can be imported for AOT (ahead-of-time) build. At the same time, it is symbolically linked at the `colossalai.kernel.extensions` path for runtime build.
As we want to support multi-backend kernels, we have to consider multiple compiler options such as `torch.jit`, `CUDA` and `triton`, as well as multiple hardware backends such as `CPU`, `GPU` and `NPU`. To make this easy for the user, we have abstracted the kernels away into extensions and expose a single loader for each kind of kernel.
For example, if the user wants to use the CPU Adam kernel, they can simply call `load()` on the kernel loader. The kernel loader automatically selects the correct extension based on the current hardware and compiler backend, so the user does not need to worry about the details of the kernel implementation. On an ARM CPU, the ARM kernel will be built and loaded; on an x86 CPU, the x86 kernel will be loaded instead.
```python
from colossalai.kernel.kernel_loader import CPUAdamLoader
# load the kernel compatible with the current hardware
kernel = CPUAdamLoader().load()
```
![](https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/extensions.png?raw=true)
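For illustration only, the snippet below sketches the selection rule described above; the function name and error message are assumptions made for this example, not the actual `KernelLoader` implementation.
```python
# A minimal sketch (not the real loader code): keep the extensions whose
# hardware requirement is met, then build and load the highest-priority one.
def select_and_load(extension_classes):
    candidates = [cls() for cls in extension_classes]
    available = [ext for ext in candidates if ext.is_available()]
    if not available:
        raise RuntimeError("No extension is compatible with the current hardware")
    best = max(available, key=lambda ext: ext.priority)
    best.assert_compatible()
    return best.load()
```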
## 🛠 API Usage
To make `colossalai.kernel` easy to use, we expose some simple APIs which you can use based on your scenario.
- Case 1: Simply load a kernel
```python
from colossalai.kernel.kernel_loader import CPUAdamLoader
# load the kernel compatible with the current hardware
kernel = CPUAdamLoader().load()
```
- Case 2: Load a specific kernel
This case applies if you are familiar with the extensions available.
```python
from colossalai.kernel.kernel_loader import CPUAdamLoader
# load the kernel by giving the kernel name
kernel = CPUAdamLoader().load(ext_name="cpu_adam_arm")
```
- Case 3: Register your own extension
This case applies if you know how to write an extension. If you do not know how, you can refer to the section below.
```python
from colossalai.kernel.kernel_loader import CPUAdamLoader
from colossalai.kernel.base_extension import _Extension
# create your own extension class
class MyExtension(_Extension):
def __init__(self):
self._name = "my_extension"
self._support_aot = True
self._support_jit = True
self.priority = 10
# implementation here
...
# register your extension
# you can use the priority value to make sure your kernel will be loaded by default
CPUAdamLoader.register_extension(MyExtension)
# load the kernel
kernel = CPUAdamLoader().load()
```
## 🏗 Write a customized extension
It is easy to write a customized extension. If you have experience writing CUDA/triton kernels, you will get familiar with the process quickly.
You just need to inherit the `_Extension` base class or a backend-specific class such as `_CudaExtension` and implement the abstract methods. Then, register your extension with the kernel loader as shown in Case 3 above. The kernel loader will automatically select the correct extension based on the priority score, the current hardware and the compiler backend.
```python
from colossalai.kernel.base_extension import _Extension
class MyExtension(_Extension):
def __init__(self):
self._name = "my_extension"
self._support_aot = True
self._support_jit = True
self.priority = 10
def is_available(self) -> bool:
"""
Return if the required hardware can be found.
"""
...
def assert_compatible(self) -> None:
"""
Check if the hardware required by the kernel is compatible.
"""
...
def build_aot(self) -> Union["CppExtension", "CUDAExtension"]:
"""
If this kernel can be built AOT, it should return an extension object
to Python setuptools for compilation.
"""
...
def build_jit(self) -> Callable:
"""
Build extension kernel just in time.
"""
...
def load(self):
"""
The API called by the user to get the kernel.
"""
...
```
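As a more concrete sketch, the hypothetical extension below builds on `_CppExtension`, which already implements `build_aot`, `build_jit` and `load`; the class name, kernel name, source path and flags here are assumptions made for the example, and the import path may differ in your setup.
```python
import platform

from colossalai.kernel.extensions.cpp_extension import _CppExtension  # import path may differ


class MyCppExtension(_CppExtension):
    def __init__(self):
        super().__init__(name="my_cpp_kernel")  # hypothetical kernel name

    def is_available(self) -> bool:
        # only require a Linux x86_64 host for this example
        return platform.system() == "Linux" and platform.machine() == "x86_64"

    def assert_compatible(self) -> None:
        assert self.is_available(), f"[extension] {self.name} requires a Linux x86_64 host"

    # the three build inputs required by _CppExtension
    def sources_files(self):
        return [self.csrc_abs_path("my_kernel.cpp")]  # hypothetical source file

    def include_dirs(self):
        return []

    def cxx_flags(self):
        return ["-O3", "-fopenmp"]
```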
## ✏️ Acknowledgement
This module is written from scratch, but we learnt a lot by looking into [DeepSpeed's op_builder](https://github.com/microsoft/DeepSpeed/tree/master/op_builder). We wish to acknowledge their great work and contributions to the open-source community.

View File

@ -0,0 +1,35 @@
from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension
from .inference import InferenceOpsCudaExtension
from .layernorm import LayerNormCudaExtension
from .moe import MoeCudaExtension
from .optimizer import FusedOptimizerCudaExtension
from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension
ALL_EXTENSIONS = [
CpuAdamArmExtension,
CpuAdamX86Extension,
LayerNormCudaExtension,
MoeCudaExtension,
FusedOptimizerCudaExtension,
InferenceOpsCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
FlashAttentionDaoCudaExtension,
FlashAttentionSdpaCudaExtension,
FlashAttentionNpuExtension,
]
__all__ = [
"CpuAdamArmExtension",
"CpuAdamX86Extension",
"LayerNormCudaExtension",
"MoeCudaExtension",
"FusedOptimizerCudaExtension",
"InferenceOpsCudaExtension",
"ScaledMaskedSoftmaxCudaExtension",
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
"FlashAttentionDaoCudaExtension",
"FlashAttentionSdpaCudaExtension",
"FlashAttentionNpuExtension",
]

View File

@ -0,0 +1,82 @@
import hashlib
import os
from abc import ABC, abstractmethod
from typing import Callable, Union
__all__ = ["_Extension"]
class _Extension(ABC):
def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1):
self._name = name
self._support_aot = support_aot
self._support_jit = support_jit
self.priority = priority
@property
def name(self):
return self._name
@property
def support_aot(self):
return self._support_aot
@property
def support_jit(self):
return self._support_jit
@staticmethod
def get_jit_extension_folder_path():
"""
        Kernels compiled at runtime are stored in a shared cache folder for reuse.
        The folder is located at ~/.cache/colossalai/torch_extensions/<cache-folder>.
        The name of the <cache-folder> follows the format:
        torch<torch_version_major>.<torch_version_minor>_<device_name>-<device_version>-<hash>
        The <hash> suffix is the SHA-256 hash of the file path of the installed `colossalai` package.
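        For example, with a hypothetical torch 2.1 and CUDA 12.1 setup, the folder name would
        look like torch2.1_cuda-12.1-<hash>; the exact device name and version come from the
        detected accelerator.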
"""
import torch
import colossalai
from colossalai.accelerator import get_accelerator
# get torch version
torch_version_major = torch.__version__.split(".")[0]
torch_version_minor = torch.__version__.split(".")[1]
# get device version
device_name = get_accelerator().name
device_version = get_accelerator().get_version()
# use colossalai's file path as hash
hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest()
# concat
home_directory = os.path.expanduser("~")
extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}"
cache_directory = os.path.join(home_directory, extension_directory)
return cache_directory
@abstractmethod
def is_available(self) -> bool:
"""
Check if the hardware required by the kernel is available.
"""
@abstractmethod
def assert_compatible(self) -> None:
"""
Check if the hardware required by the kernel is compatible.
"""
@abstractmethod
def build_aot(self) -> Union["CppExtension", "CUDAExtension"]:
pass
@abstractmethod
def build_jit(self) -> Callable:
pass
@abstractmethod
def load(self) -> Callable:
pass

View File

@ -0,0 +1,134 @@
import importlib
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List
from .base_extension import _Extension
__all__ = ["_CppExtension"]
class _CppExtension(_Extension):
def __init__(self, name: str, priority: int = 1):
super().__init__(name, support_aot=True, support_jit=True, priority=priority)
# we store the op as an attribute to avoid repeated building and loading
self.cached_op = None
# build-related variables
self.prebuilt_module_path = "colossalai._C"
self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
def csrc_abs_path(self, path):
return os.path.join(self.relative_to_abs_path("csrc"), path)
def relative_to_abs_path(self, code_path: str) -> str:
"""
        This function takes in a path relative to the colossalai root directory and returns the absolute path.
"""
        # start from the path of this file and walk up the directory tree
        # until we reach the directory named "extensions", which is the root of the extension module
current_file_path = Path(__file__)
while True:
if current_file_path.name == "extensions":
break
else:
current_file_path = current_file_path.parent
extension_module_path = current_file_path
code_abs_path = extension_module_path.joinpath(code_path)
return str(code_abs_path)
    # helper functions
def strip_empty_entries(self, args):
"""
Drop any empty strings from the list of compile and link flags
"""
return [x for x in args if len(x) > 0]
def import_op(self):
"""
This function will import the op module by its string name.
"""
return importlib.import_module(self.prebuilt_import_path)
def build_aot(self) -> "CppExtension":
from torch.utils.cpp_extension import CppExtension
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
def build_jit(self) -> None:
from torch.utils.cpp_extension import load
build_directory = _Extension.get_jit_extension_folder_path()
build_directory = Path(build_directory)
build_directory.mkdir(parents=True, exist_ok=True)
# check if the kernel has been built
compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
if kernel_file_path.exists():
compiled_before = True
# load the kernel
if compiled_before:
print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
else:
print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
build_start = time.time()
op_kernel = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_ldflags=[],
build_directory=str(build_directory),
)
build_duration = time.time() - build_start
if compiled_before:
print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
else:
print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
return op_kernel
    # functions that must be overridden by subclasses
@abstractmethod
def sources_files(self) -> List[str]:
"""
This function should return a list of source files for extensions.
"""
@abstractmethod
def include_dirs(self) -> List[str]:
"""
This function should return a list of include files for extensions.
"""
@abstractmethod
def cxx_flags(self) -> List[str]:
"""
This function should return a list of cxx compilation flags for extensions.
"""
def load(self):
try:
op_kernel = self.import_op()
except (ImportError, ModuleNotFoundError):
# if import error occurs, it means that the kernel is not pre-built
# so we build it jit
op_kernel = self.build_jit()
return op_kernel

View File

@ -0,0 +1,4 @@
from .cpu_adam_arm import CpuAdamArmExtension
from .cpu_adam_x86 import CpuAdamX86Extension
__all__ = ["CpuAdamArmExtension", "CpuAdamX86Extension"]

View File

@ -0,0 +1,41 @@
import platform
from ..cpp_extension import _CppExtension
class CpuAdamArmExtension(_CppExtension):
def __init__(self):
super().__init__(name="cpu_adam_arm")
def is_available(self) -> bool:
# only arm allowed
return platform.machine() == "aarch64"
def assert_compatible(self) -> None:
arch = platform.machine()
assert (
arch == "aarch64"
), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}"
    # the four required functions
def sources_files(self):
ret = [
self.csrc_abs_path("arm/cpu_adam_arm.cpp"),
]
return ret
def include_dirs(self):
return []
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-g",
"-Wno-reorder",
"-fopenmp",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
return []

View File

@ -0,0 +1,54 @@
import platform
from ..cuda_extension import _CudaExtension
from ..utils import append_nvcc_threads
class CpuAdamX86Extension(_CudaExtension):
def __init__(self):
super().__init__(name="cpu_adam_x86")
def is_available(self) -> bool:
return platform.machine() == "x86_64" and super().is_available()
def assert_compatible(self) -> None:
arch = platform.machine()
assert (
arch == "x86_64"
), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}"
super().assert_compatible()
    # the four required functions
def sources_files(self):
ret = [
self.csrc_abs_path("x86/cpu_adam.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-lcudart",
"-lcublas",
"-g",
"-Wno-reorder",
"-fopenmp",
"-march=native",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
]
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@ -0,0 +1,11 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
__all__ = [
"LayerNorm",
"MultiHeadAttention",
"FusedScaleMaskSoftmax",
"ScaledUpperTriangMaskedSoftmax",
"AttnMaskType",
]

View File

@ -0,0 +1,304 @@
#include "cpu_adam_arm.h"
void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4 = vdivq_f32(grad_4, loss_scale_vec);
}
float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
float32x4_t variance_4 =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
}
momentum_4 = vmulq_f32(momentum_4, betta1_4);
momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
variance_4 = vmulq_f32(variance_4, betta2_4);
grad_4 = vmulq_f32(grad_4, grad_4);
variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
grad_4 = vsqrtq_f32(variance_4);
grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
grad_4 = vdivq_f32(momentum_4, grad_4);
if (_weight_decay > 0 && _adamw_mode) {
param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
}
param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
simd_store_offset(_params, param_dtype, param_4, i);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = scalar_load_offset(grads, grad_dtype, k);
if (loss_scale > 0) {
grad /= loss_scale;
}
float param = scalar_load_offset(_params, param_dtype, k);
float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) {
param += w_decay * param;
}
param = grad * step_size + param;
scalar_store_offset(_params, param_dtype, param, k);
scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
}
}
}
}
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
float32x4_t grad_4[4];
float32x4_t momentum_4[4];
float32x4_t variance_4[4];
float32x4_t param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
if (_param_size > rounded_size) {
Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
float32x4_t grad_4[8];
float32x4_t momentum_4[8];
float32x4_t variance_4[8];
float32x4_t param_4[8];
#pragma unroll 4
for (int j = 0; j < 8; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
if (_param_size > rounded_size) {
Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
exp_avg_sq_c.data_ptr(), params_c.numel(),
params_c.scalar_type(), grads_c.scalar_type(),
exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
}
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
.def(py::init<float, float, float, float, float, bool>())
.def("step", &AdamOptimizer::step);
}

View File

@ -0,0 +1,201 @@
#pragma once
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cmath>
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__aarch64__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
return vld1q_f32(ptr_f + offset);
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
return vcvt_f32_f16(vld1_f16(ptr_h + offset));
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
// return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
return simd_load_offset(ptr, dtype, 0);
}
inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<float32_t *>(ptr);
vst1q_f32(ptr_f + offset, data);
break;
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<float16_t *>(ptr);
vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
break;
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
// vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
// break;
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
return simd_store_offset(ptr, dtype, data, 0);
}
inline float32x4_t simd_set(float value) {
auto val = static_cast<float32_t>(value);
return vdupq_n_f32(val);
}
#endif
inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return *(reinterpret_cast<const float *>(ptr) + offset);
case at::ScalarType::Half:
return static_cast<float>(
*(reinterpret_cast<const at::Half *>(ptr) + offset));
// case at::ScalarType::BFloat16:
// return static_cast<float>(
// *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
*(reinterpret_cast<float *>(ptr) + offset) = data;
break;
case at::ScalarType::Half:
*(reinterpret_cast<at::Half *>(ptr) + offset) = data;
break;
      //   case at::ScalarType::BFloat16:
      //     *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
      //     break;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return reinterpret_cast<float *>(ptr) + offset;
case at::ScalarType::Half:
return reinterpret_cast<at::Half *>(ptr) + offset;
// case at::ScalarType::BFloat16:
// return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
#define STEP(SPAN) \
void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
void *_exp_avg_sq, size_t _param_size, \
at::ScalarType param_dtype, at::ScalarType grad_dtype, \
at::ScalarType exp_avg_dtype, \
at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
class AdamOptimizer {
private:
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
size_t _step;
float _bias_correction1;
float _bias_correction2;
bool _adamw_mode;
public:
AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {}
~AdamOptimizer() {}
STEP(1)
STEP(4)
STEP(8)
inline void IncrementStep(size_t step, float beta1, float beta2) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
inline void update_state(float lr, float epsilon, float weight_decay,
bool bias_correction) {
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;
_bias_correction1 = 1.0f;
_bias_correction2 = 1.0f;
if (bias_correction == 1) {
_bias_correction1 = 1 - _betta1_t;
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
}
}
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
};

View File

@ -0,0 +1,224 @@
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#pragma once
#include <ATen/ATen.h>
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
TYPE, NAME, ...) \
if (HIGH_PRECISION) { \
const bool high_precision = true; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
} else { \
const bool high_precision = false; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::Half) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Half && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else { \
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
"'"); \
}

View File

@ -0,0 +1,38 @@
#pragma once
#include <ATen/ATen.h>
#include "micros.h"
namespace colossalAI {
namespace common {
template <typename T>
struct MPTypeTrait {
using Type = float;
};
template <>
struct MPTypeTrait<float> {
using Type = float;
};
template <>
struct MPTypeTrait<at::Half> {
using Type = float;
};
template <>
struct MPTypeTrait<at::BFloat16> {
using Type = float;
};
template <bool high_precision, typename T>
struct ScalarTypeTrait {
using Type =
typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
T>::type;
};
} // namespace common
} // namespace colossalAI

View File

@ -0,0 +1,134 @@
#pragma once
#include <exception>
#include <iostream>
#include <string>
namespace colossalAI {
namespace common {
class Target {
public:
enum class OS : int {
Unk = -1,
Linux,
Windows,
};
enum class Arch : int {
Unk = -1,
X86,
Arm,
NVGPU,
AMDGPU,
Ascend,
};
enum class BitLen : int {
Unk = -1,
k32,
k64,
};
explicit Target(OS os, Arch arch, BitLen bitlen)
: os_(os), arch_(arch), bitlen_(bitlen) {}
bool defined() const {
return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk);
}
std::string str() const {
std::string s{"OS: "};
switch (os_) {
case OS::Unk:
s += "Unk";
break;
case OS::Linux:
s += "Linux";
break;
case OS::Windows:
s += "Windows";
break;
default:
throw std::invalid_argument("Invalid OS type!");
}
s += "\t";
s += "Arch: ";
switch (arch_) {
case Arch::Unk:
s += "Unk";
break;
case Arch::X86:
s += "X86";
break;
case Arch::Arm:
s += "Arm";
break;
case Arch::NVGPU:
s += "NVGPU";
break;
case Arch::AMDGPU:
s += "AMDGPU";
break;
case Arch::Ascend:
s += "Ascend";
break;
default:
throw std::invalid_argument("Invalid Arch type!");
}
s += "\t";
s += "BitLen: ";
switch (bitlen_) {
case BitLen::Unk:
s += "Unk";
break;
case BitLen::k32:
s += "k32";
break;
case BitLen::k64:
s += "k64";
break;
default:
throw std::invalid_argument("Invalid target bit length!");
}
return s;
}
OS os() const { return os_; }
Arch arch() const { return arch_; }
BitLen bitlen() const { return bitlen_; }
static Target DefaultX86Target();
static Target DefaultArmTarget();
static Target DefaultRocmTarget();
static Target DefaultAscendTarget();
static Target DefaultCUDATarget() {
return Target(OS::Linux, Arch::NVGPU, BitLen::k64);
}
friend std::ostream& operator<<(std::ostream& os, const Target& target);
friend bool operator==(const Target& lhs, const Target& rhs);
friend bool operator!=(const Target& lhs, const Target& rhs);
private:
OS os_{OS::Unk};
Arch arch_{Arch::Unk};
BitLen bitlen_{BitLen::Unk};
};
inline std::ostream& operator<<(std::ostream& os, const Target& target) {
  os << target.str();
  return os;
}
inline bool operator==(const Target& lhs, const Target& rhs) {
  return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) &&
         (lhs.bitlen_ == rhs.bitlen_);
}
inline bool operator!=(const Target& lhs, const Target& rhs) {
  return !(lhs == rhs);
}
} // namespace common
} // namespace colossalAI

View File

@ -0,0 +1,75 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>
#include "../common/micros.h"
#include "../common/mp_type_traits.h"
template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
}
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
const scalar_t* __restrict__ ins_data,
scalar_t* __restrict__ outs_data,
const int64_t numel) {
using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
const int64_t grid_size = blockDim.x * gridDim.x;
if(idx > numel) {
return;
}
for(int64_t i = idx; i < numel; i += grid_size) {
scalar_t x = ins_data[i];
scalar_t y = ins_data[i+numel];
outs_data[i] = static_cast<scalar_t>(static_cast<MT>(ACT_FN(x)) * static_cast<MT>(y));
}
}
// Note(LiuYang): This function is designed for a computation pattern like
// silu(x[:half_1stdim]) * (x[half_1stdim:])
torch::Tensor silu_and_mul(const torch::Tensor& ins)
{
  // Note(LiuYang): According to the torch doc, vec() may cost a lot, but I didn't find a better api
// to manipulate ins_shape which is IntArrayRef
auto ins_shape = ins.sizes().vec();
ins_shape[0] = ins_shape[0]/2;
if (ins_shape[0] == 1) {
ins_shape.erase(ins_shape.begin());
}
auto outs = torch::zeros(ins_shape,ins.options());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Note(Liuyang): numel of ins must be divisible by 2
int64_t numel = ((torch::numel(ins)) >> 1);
  // Note(LiuYang): For better performance in the special case where the input is [2, 64, 11008],
  // this part of the code is commented out for now because calculating a better config also costs a little time
// colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
// auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
// dim3 grid = config.grid;
// dim3 block = config.block;
dim3 grid((numel+255)/256);
dim3 block(256);
DISPATCH_FLOAT_HALF_AND_BFLOAT(
ins.scalar_type(),
"silu_and_mul",
act_and_mul_kernel<scalar_t,silu_kernel<scalar_t>><<<grid, block, 0, stream>>>(
ins.data_ptr<scalar_t>(),
outs.data_ptr<scalar_t>(),
numel
);)
AT_CUDA_CHECK(cudaGetLastError());
return outs;
}

View File

@ -0,0 +1,182 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
template<typename scalar_t, bool Aligned, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ cu_seqlens,
const int* __restrict__ block_tables,
const int head_num,
const int head_dim,
const int block_size,
const int batch_size,
const int block_table_stride,
const int64_t key_stride,
const int64_t value_stride
)
{
const int seq_token_id = blockIdx.x;
const int seq_id = blockIdx.y;
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];
if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
return ;
}
const int block_offset = seq_token_id % block_size;
const int hidden_size = head_num * head_dim;
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
int head_id;
int head_offset;
int64_t key_src_id;
int64_t value_src_id;
int64_t target_id;
int i = threadIdx.x * VecSize;
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
head_id = i / head_dim;
head_offset = i % head_dim;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}
// tail process
if (!Aligned) {
for (; i < hidden_size; ++i ) {
head_id = i / head_dim;
head_offset = i % head_dim;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}
}
}
template<typename scalar_t>
void apply_context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch)
{
int num_tokens = key.size(0);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int batch_size = block_tables.size(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int block_table_stride = block_tables.stride(0);
int vec_size = get_vec_size<scalar_t>(key);
bool aligned = true;
if (head_dim % vec_size != 0) {
aligned = false;
}
int thread_nums = head_num * head_dim / vec_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(max_seq_len_in_batch, batch_size);
dim3 block(std::min(thread_nums, 512));
#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
do { \
context_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
key.data_ptr<scalar_t>(), \
value.data_ptr<scalar_t>(), \
key_cache.data_ptr<scalar_t>(), \
value_cache.data_ptr<scalar_t>(), \
sequence_lengths.data_ptr<int>(), \
cu_seqlens.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
head_num, \
head_dim, \
block_size, \
batch_size, \
block_table_stride, \
key_stride, \
value_stride \
); \
} while(0)
#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \
do { \
switch (vec_size) { \
case 1: \
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \
break; \
case 2: \
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \
break; \
case 4: \
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \
break; \
default: \
AT_ERROR("Unsupported vectorized size ", vec_size); \
break; \
} \
} while(0)
if (aligned) {
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
}
else {
CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch)
{
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"context_kv_cache_memcpy",
apply_context_kv_cache_memcpy<scalar_t>(
key,
value,
key_cache,
value_cache,
sequence_lengths,
cu_seqlens,
block_tables,
max_seq_len_in_batch
);)
}

View File

@ -0,0 +1,162 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
template<typename scalar_t, bool Aligned, int VecSize>
__global__ void decode_kv_cache_memcpy_kernel(
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables,
const int head_num,
const int head_dim,
const int block_size,
const int64_t key_stride,
const int64_t value_stride,
const int block_table_stride
)
{
const int seq_id = blockIdx.x;
const int seq_len = sequence_lengths[seq_id] - 1;
const int block_offset = seq_len % block_size;
const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size];
const int hidden_size = head_num * head_dim;
if ( block_id < 0 ) {
return ;
}
int i = threadIdx.x * VecSize;
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
}
if (!Aligned) {
for (; i < hidden_size; ++i ) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
}
}
}
template<typename scalar_t>
void apply_decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
int num_tokens = key.size(0);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int block_table_stride = block_tables.stride(0);
int vec_size = get_vec_size<scalar_t>(key);
bool aligned = true;
if (head_dim % vec_size != 0) {
aligned = false;
}
int thread_nums = head_num * head_dim / vec_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));
#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size) \
do { \
decode_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
key.data_ptr<scalar_t>(), \
value.data_ptr<scalar_t>(), \
key_cache.data_ptr<scalar_t>(), \
value_cache.data_ptr<scalar_t>(), \
sequence_lengths.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
head_num, \
head_dim, \
block_size, \
key_stride, \
value_stride, \
block_table_stride \
); \
} while(0)
#define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned, __vec_size) \
do { \
switch (__vec_size) { \
case 1: \
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 1); \
break; \
case 2: \
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 2); \
break; \
case 4: \
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, 4); \
break; \
default: \
AT_ERROR("Unsupported vectorized size ", __vec_size); \
break; \
} \
} while(0)
if (aligned) {
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(true, vec_size);
}
else {
DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH_VEC_SIZE_CASE(false, vec_size);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
DISPATCH_FLOAT_HALF_AND_BFLOAT(
key.scalar_type(),
"decode_kv_cache_memcpy",
apply_decode_kv_cache_memcpy<scalar_t>(
key,
value,
key_cache,
value_cache,
sequence_lengths,
block_tables
);)
}

View File

@ -0,0 +1,74 @@
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <functional>
#include "../utils/micros.h"
// Note(LiuYang): This file provides type conversion utilities for data types,
// including POD and CUDA built-in types such as half and __nv_bfloat16
namespace colossalAI {
namespace cuda {
namespace funcs {
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = at::Half;
};
template <>
struct TypeConverter<at::Half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = at::BFloat16;
};
template <>
struct TypeConverter<at::BFloat16> {
using Type = __nv_bfloat162;
};
template <typename From, typename To>
struct CastFunctor : public std::unary_function<From, To> {
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
};
#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT, \
FUNCTION_MODIFIER) \
template <> \
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; } \
};
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, nv_bfloat162,
__float2bfloat162_rn(val), DEVICE)
#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
} // namespace funcs
} // namespace cuda
} // namespace colossalAI

View File

@ -0,0 +1,92 @@
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <functional>
#include "../utils/micros.h"
namespace colossalAI {
namespace cuda {
namespace funcs {
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
// Note(LiuYang): This file provides basic math operations for data types,
// including POD and CUDA built-in types such as half and __nv_bfloat16
template <typename LT, typename RT, typename RET, BinaryOpType Op>
struct BinaryOpFunctor;
#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
FUNCTION_MODIFIER, ARGS...) \
template <ARGS> \
struct BinaryOpFunctor<T, T, T, BINARY_OP_TYPE> \
: public std::binary_function<T, T, T> { \
FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \
};
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs* rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs),
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs),
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd,
__hadd(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd,
__hadd2(lhs, rhs), DEVICE)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
__hadd(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd,
__hadd2(lhs, rhs), DEVICE)
#else
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
__float2bfloat16(__bfloat162float(lhs) +
__bfloat162float(rhs)),
DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, BinaryOpType::kAdd,
__floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs),
__high2float(lhs) + __high2float(rhs)),
DEVICE)
#endif
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul,
__hmul(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul,
__hmul2(lhs, rhs), DEVICE)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
__hmul(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul,
__hmul2(lhs, rhs), DEVICE)
#else
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
__float2bfloat16(__bfloat162float(lhs) *
__bfloat162float(rhs)),
DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, BinaryOpType::kMul,
__floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs),
__high2float(lhs) * __high2float(rhs)),
DEVICE)
#endif
#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION
} // namespace funcs
} // namespace cuda
} // namespace colossalAI

View File

@ -0,0 +1,481 @@
// In the transformers source code, HuggingFace uses fp16 to compute RoPE, so we follow the same precision here
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
#include "../common/mp_type_traits.h"
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void apply_emb_rotary_compute(
scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
const m_scalar_t* __restrict__ sin_ptr, const int64_t stride,
const int token_id, const int shard_block_size, const int half_head_dim,
const int head_num, const int head_dim) {
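  // Each iteration rotates one (x, y) pair per vector lane, where x comes from
  // the first half of the head dimension and y from the second half:
  //   out_x = x * cos - y * sin,  out_y = y * cos + x * sin
  // which is the standard rotary position embedding.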
scalar_t x[VecSize];
scalar_t y[VecSize];
scalar_t out_x[VecSize];
scalar_t out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
i += blockDim.x * VecSize) {
const int head_offset = i % half_head_dim;
const int shard_offset =
(head_offset / shard_block_size) * shard_block_size +
(head_offset % shard_block_size) / VecSize;
const int64_t addr_offset =
token_id * stride + (i / half_head_dim) * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(x, src + addr_offset);
copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim);
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
}
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y);
}
}
template <typename scalar_t, int VecSize>
__device__ void apply_kv_memcopy(
scalar_t* __restrict__ src, scalar_t* __restrict__ cache,
const int64_t stride, const int token_id, const int block_id,
const int hidden_size, const int block_size, const int block_offset,
const int head_dim, const int half_head_dim) {
for (int i = threadIdx.x * VecSize; i < hidden_size / 2;
i += blockDim.x * VecSize) {
const int head_id = i / half_head_dim;
const int head_offset = i % half_head_dim;
const int64_t src_id = token_id * stride + head_id * head_dim + head_offset;
const int64_t target_id = block_id * hidden_size * block_size +
head_id * block_size * head_dim +
block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id);
copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim,
src + src_id + half_head_dim);
}
}
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void cos_sin_memory_access(
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id,
const int shard_block_size, const int cos_stride, const int sin_stride,
const int half_head_dim) {
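  // Layout note: cos/sin values are scattered into shared memory so that the
  // j-th element of a VecSize-wide vector ends up at index j * 32 +
  // shard_offset, matching the cos_ptr/sin_ptr indexing used in
  // apply_emb_rotary_compute and apply_k_rotary_emb_compute.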
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
// We assume that the value of head_dim is less than 128*128.
const int shard_offset = (i % shard_block_size) / VecSize;
const int shard_head =
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]);
sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]);
}
}
template <typename scalar_t, typename m_scalar_t, int VecSize>
__device__ void apply_k_rotary_emb_compute(
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables, const int64_t key_stride,
const int64_t value_stride, const int token_id,
const int block_table_stride, const int head_num, const int head_dim,
const int kv_head_num, const int block_size, const int half_head_dim,
const int shard_block_size) {
const int seq_len = sequence_lengths[token_id] - 1;
const int block_offset = seq_len % block_size;
const int block_id =
block_tables[token_id * block_table_stride + seq_len / block_size];
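  // A negative block id presumably means no cache block is allocated for this
  // token, so there is nothing to rotate or copy.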
if (block_id < 0) {
return;
}
scalar_t x[VecSize];
scalar_t y[VecSize];
scalar_t out_x[VecSize];
scalar_t out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
i += blockDim.x * VecSize) {
const int head_offset = i % half_head_dim;
const int shard_offset =
(head_offset / shard_block_size) * shard_block_size +
(head_offset % shard_block_size) / VecSize;
const int64_t addr_offset =
token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
const int64_t target_id = block_id * head_num * head_dim * block_size +
(i / half_head_dim) * block_size * head_dim +
block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(x, key + addr_offset);
copy_vector<scalar_t, VecSize>(y, key + addr_offset + half_head_dim);
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
}
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim,
out_y);
}
// apply value memcopy
apply_kv_memcopy<scalar_t, VecSize>(
value, value_cache, value_stride, token_id, block_id, head_num * head_dim,
block_size, block_offset, head_dim, half_head_dim);
}
template<typename scalar_t, typename m_scalar_t, int VecSize>
__global__ void rotary_embedding_and_cache_copy_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
scalar_t* __restrict__ value,
const scalar_t* __restrict__ cos,
const scalar_t* __restrict__ sin,
scalar_t* __restrict__ key_cache,
scalar_t* __restrict__ value_cache,
const int* __restrict__ sequence_lengths,
const int* __restrict__ block_tables,
const int64_t query_stride,
const int64_t key_stride,
const int64_t value_stride,
const int64_t half_shard_element_num,
const int cos_stride,
const int sin_stride,
const int block_table_stride,
const int head_num,
const int head_dim,
const int kv_head_num,
const int block_size
) {
const int token_id = blockIdx.x;
const int half_head_dim = head_dim / 2;
const int shard_block_size = VecSize * 32;
extern __shared__ char shard_ptr[];
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key and copy kv
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
}
template<typename scalar_t, typename m_scalar_t, int VecSize>
__global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query,
scalar_t* __restrict__ key,
const scalar_t* __restrict__ cos,
const scalar_t* __restrict__ sin,
const int64_t query_stride,
const int64_t key_stride,
const int64_t half_shard_element_num,
const int cos_stride,
const int sin_stride,
const int head_num,
const int head_dim,
const int kv_head_num
) {
const int token_id = blockIdx.x;
const int half_head_dim = head_dim / 2;
const int shard_block_size = VecSize * 32;
extern __shared__ char shard_ptr[];
m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
// apply cos_sin memcopy
cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
__syncthreads();
//compute query
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
}
template<typename scalar_t, bool high_precision>
void apply_rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim]
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
{
int num_tokens = query.size(0);
int head_num = query.size(1);
int head_dim = query.size(2);
int kv_head_num = key.size(1);
int block_size = key_cache.size(2);
int64_t query_stride = query.stride(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int cos_stride = cos.stride(0);
int sin_stride = sin.stride(0);
int block_table_stride = block_tables.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
int vec_size = get_vec_size<scalar_t>(query);
if ((head_dim / 2) % vec_size != 0) {
    // Disable the vectorized-load optimization when half of head_dim is not divisible by vec_size.
vec_size = 1;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int thread_nums = head_num * head_dim / vec_size / 2;
const int shard_block_size = vec_size * 32 * 2;
dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));
  int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
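  // The dynamic shared memory buffer holds shard_element_num elements of
  // m_scalar_t and is split in half inside the kernel: cos in the first half,
  // sin in the second, hence shard_element_num / 2 is passed below.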
switch (vec_size) {
case 1:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
break;
case 2:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
break;
case 4:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
}
AT_CUDA_CHECK(cudaGetLastError());
}
template<typename scalar_t, bool high_precision>
void apply_rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [total_tokens, head_dim]
at::Tensor& sin // [total_tokens, head_dim]
){
int num_tokens = query.size(0);
int head_num = query.size(1);
int head_dim = query.size(2);
int kv_head_num = key.size(1);
int query_stride = query.stride(0);
int key_stride = key.stride(0);
int cos_stride = cos.stride(0);
int sin_stride = sin.stride(0);
using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
int vec_size = get_vec_size<scalar_t>(query);
if ((head_dim / 2) % vec_size != 0) {
    // Disable the vectorized-load optimization when half of head_dim is not divisible by vec_size.
vec_size = 1;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int thread_nums = head_num * head_dim / vec_size / 2;
const int shard_block_size = vec_size * 32 * 2;
dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));
  int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
switch (vec_size) {
case 1:
rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
query_stride,
key_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
head_num,
head_dim,
kv_head_num
);
break;
case 2:
rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
query_stride,
key_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
head_num,
head_dim,
kv_head_num
);
break;
case 4:
rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
query_stride,
key_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
head_num,
head_dim,
kv_head_num
);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
break;
}
AT_CUDA_CHECK(cudaGetLastError());
}
void rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim]
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision)
{
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
high_precision,
query.scalar_type(),
"rotary_embedding_and_cache_copy",
apply_rotary_embedding_and_cache_copy<scalar_t, high_precision>(
query,
key,
value,
cos,
sin,
key_cache,
value_cache,
sequence_lengths,
block_tables
);)
}
void rotary_embedding(
at::Tensor& query, // [total_tokens, head_num, head_dim]
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [total_tokens, head_dim]
at::Tensor& sin, // [total_tokens, head_dim]
bool high_precision
){
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
high_precision,
query.scalar_type(),
"rotary_embedding",
apply_rotary_embedding<scalar_t, high_precision>(
query,
key,
cos,
sin
);)
}

View File

@ -0,0 +1,215 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
#include "stdio.h"
template <typename scalar_t, bool Aligned, int VecSize>
__device__ void apply_cos_and_sin_memcopy(
scalar_t* __restrict__ cos,
scalar_t* __restrict__ sin,
const scalar_t* __restrict__ cos_cache_ptr,
const scalar_t* __restrict__ sin_cache_ptr,
const int* __restrict__ sequence_lengths,
const int head_dim,
const int dest_offset_id,
const int src_offset_id
) {
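  // Vectorized copy of one row of cos/sin from the cache; when head_dim is not
  // a multiple of VecSize (Aligned == false), the remaining tail elements are
  // copied one by one below.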
int begin_id = threadIdx.x * VecSize;
for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){
copy_vector<scalar_t, VecSize>(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id);
copy_vector<scalar_t, VecSize>(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id);
}
if (!Aligned) {
for (; begin_id < head_dim; ++begin_id ) {
cos[dest_offset_id + begin_id] = cos_cache_ptr[src_offset_id + begin_id];
sin[dest_offset_id + begin_id] = sin_cache_ptr[src_offset_id + begin_id];
}
}
}
template <typename scalar_t, bool Aligned, int VecSize>
__global__ void apply_get_context_cos_and_sin_kernel(
scalar_t* __restrict__ cos,
scalar_t* __restrict__ sin,
const scalar_t* __restrict__ cos_cache_ptr,
const scalar_t* __restrict__ sin_cache_ptr,
const int* __restrict__ sequence_lengths,
const int* __restrict__ cumsum_lengths,
const int batch_size,
const int head_dim
) {
int token_id = blockIdx.x;
if ( token_id >= sequence_lengths[blockIdx.y] ) {
return ;
}
int src_offset_id = token_id * head_dim;
int dest_offset_id = src_offset_id;
if (blockIdx.y > 0) {
dest_offset_id += cumsum_lengths[blockIdx.y - 1] * head_dim;
}
apply_cos_and_sin_memcopy<scalar_t, Aligned, VecSize>(
cos,
sin,
cos_cache_ptr,
sin_cache_ptr,
sequence_lengths,
head_dim,
dest_offset_id,
src_offset_id
);
}
template <typename scalar_t, bool Aligned, int VecSize>
__global__ void apply_get_decode_cos_and_sin_kernel(
scalar_t* __restrict__ cos,
scalar_t* __restrict__ sin,
const scalar_t* __restrict__ cos_cache_ptr,
const scalar_t* __restrict__ sin_cache_ptr,
const int* __restrict__ sequence_lengths,
const int batch_size,
const int head_dim
) {
int src_offset_id = ( sequence_lengths[blockIdx.y] - 1 ) * head_dim;
int dest_offset_id = blockIdx.y * head_dim;
apply_cos_and_sin_memcopy<scalar_t, Aligned, VecSize>(
cos,
sin,
cos_cache_ptr,
sin_cache_ptr,
sequence_lengths,
head_dim,
dest_offset_id,
src_offset_id
);
}
template<typename scalar_t>
void apply_get_cos_and_sin(
at::Tensor& cos_cache, // [max_rotary_position, head_dim]
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
int max_seq_len_in_batch,
bool is_prompts
) {
int token_num = cos.size(0);
int head_dim = cos.size(1);
int batch_size = sequence_lengths.size(0);
at::Tensor cumsum_lengths;
int vec_size = get_vec_size<scalar_t>(cos);
bool aligned = true;
if (head_dim % vec_size != 0) {
aligned = false;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int block_size_y;
int block_size_x;
if (is_prompts) {
block_size_y = batch_size;
block_size_x = max_seq_len_in_batch;
// TODO: The cumsum operation can be fused into get_cos_and_sin kernel later on.
cumsum_lengths = torch::cumsum(sequence_lengths, 0, torch::kInt32);
}
else{
block_size_y = batch_size;
block_size_x = 1;
}
int thread_nums = (head_dim + vec_size - 1) / vec_size;
dim3 grid(block_size_x, block_size_y);
dim3 block(std::min(thread_nums, 512));
#define GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, __vec_size) \
do { \
if (is_prompts){ \
apply_get_context_cos_and_sin_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
cos.data_ptr<scalar_t>(), \
sin.data_ptr<scalar_t>(), \
cos_cache.data_ptr<scalar_t>(), \
sin_cache.data_ptr<scalar_t>(), \
sequence_lengths.data_ptr<int>(), \
cumsum_lengths.data_ptr<int>(), \
batch_size, \
head_dim \
); \
} \
else { \
apply_get_decode_cos_and_sin_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
cos.data_ptr<scalar_t>(), \
sin.data_ptr<scalar_t>(), \
cos_cache.data_ptr<scalar_t>(), \
sin_cache.data_ptr<scalar_t>(), \
sequence_lengths.data_ptr<int>(), \
batch_size, \
head_dim \
); \
} \
} while(0)
#define GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(__aligned) \
do { \
switch (vec_size) { \
case 1: \
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 1); \
break; \
case 2: \
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 2); \
break; \
case 4: \
GET_COS_AND_SIN_KERNEL_LAUNCH(__aligned, 4); \
break; \
default: \
AT_ERROR("Unsupported vectorized size ", vec_size); \
break; \
} \
} while(0)
if (aligned) {
GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(true);
}
else {
GET_COS_AND_SIN_KERNEL_LAUNCH_VEC_SIZE_CASE(false);
}
AT_CUDA_CHECK(cudaGetLastError());
}
void get_cos_and_sin(
at::Tensor& cos_cache, // [max_rotary_position, head_dim]
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
int max_seq_len_in_batch,
bool is_prompts
) {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
cos.scalar_type(),
"get_cos_and_sin",
apply_get_cos_and_sin<scalar_t>(
cos_cache,
sin_cache,
cos,
sin,
sequence_lengths,
max_seq_len_in_batch,
is_prompts
);)
}

View File

@ -0,0 +1,184 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "../funcs/op_functor.h"
namespace colossalAI {
namespace cuda {
namespace utils {
const float kReduceFloatInfNeg = -100000000.f;
const float kReduceFloatInfPos = 100000000.f;
const int kWarpSize = 32;
const unsigned int kWarpReduceMask = 0xffffffff;
enum class ReduceType { kMax = 0, kSum };
template <typename T, ReduceType rtype>
struct GetOpForReduceType;
template <typename T>
struct GetOpForReduceType<T, ReduceType::kMax> {
using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kMax>;
};
template <typename T>
struct GetOpForReduceType<T, ReduceType::kSum> {
using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kAdd>;
};
#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
for (int offset = 0; offset < LANES; ++offset) { \
*(VAL_PTR + offset) = \
OP(*(VAL_PTR + offset), \
__shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \
}
#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES)
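// The five XOR shuffles above (offsets 16, 8, 4, 2, 1 over a width of 32) form
// a butterfly reduction across the warp, leaving every lane with the reduced
// value.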
#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \
DEFAULT_VALUE, REDUCE_TYPE) \
__shared__ T shm[LANES][32]; \
int lane_id = threadIdx.x & 0x1f; \
int warp_id = threadIdx.x >> 5; \
\
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR); \
if (lane_id == 0) { \
for (int offset = 0; offset < LANES; ++offset) { \
shm[offset][warp_id] = *(VAL_PTR + offset); \
} \
} \
__syncthreads(); \
\
for (int offset = 0; offset < LANES; ++offset) { \
*(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \
? shm[offset][lane_id] \
: static_cast<T>(DEFAULT_VALUE); \
} \
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);
template <typename T, ReduceType rtype, int lanes>
__forceinline__ __device__ void warp_reduce(T* pval) {
typename GetOpForReduceType<T, rtype>::Op op;
COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes);
}
template <typename T, ReduceType rtype>
__forceinline__ __device__ constexpr T GetDefaultValueForBlockReduce() {
if constexpr (rtype == ReduceType::kSum) {
return static_cast<T>(0.0f);
} else if constexpr (rtype == ReduceType::kMax) {
return static_cast<T>(kReduceFloatInfNeg);
}
}
template <typename T, ReduceType rtype, int lanes>
__forceinline__ __device__ void block_reduce(T* pval) {
constexpr T kDefaultValue = GetDefaultValueForBlockReduce<T, rtype>();
typename GetOpForReduceType<T, rtype>::Op op;
COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue,
rtype);
}
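// Illustrative usage sketch: each thread passes a pointer to its `lanes`
// partial values and, after the call, should hold the block-wide result, e.g.
//   float partial = ...;                                 // per-thread value
//   block_reduce<float, ReduceType::kSum, 1>(&partial);  // block-wide sum
//   if (threadIdx.x == 0) out[blockIdx.x] = partial;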
#undef COLOSSAL_SHFL_FUNCTION
#undef COLOSSAL_WARP_REDUCE_IMPL
#undef COLOSSAL_BLOCK_REDUCE_IMPL
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
T* x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
T* x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final =
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
} // namespace utils
} // namespace cuda
} // namespace colossalAI

View File

@ -0,0 +1,683 @@
/* This code is from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <cuda.h>
#include <cuda_runtime.h>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include "../common/micros.h"
template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
U& mu, U& sigma2, U& count) {
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA * mu + nB * muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
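// Note: the two helpers above implement Welford's online update and Chan's
// parallel merge for mean/variance accumulation; for partitions A and B,
//   mu = (nA * muA + nB * muB) / (nA + nB)
//   M2 = M2A + M2B + delta^2 * nA * nB / (nA + nB), with delta = muB - muA.
// Here nA and nB are pre-divided by the combined count, which is algebraically
// equivalent.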
template <typename T, typename U>
__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1,
const int n2, const int i1, U& mu, U& sigma2,
U* buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu = U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1 * n2;
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l + k]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U* ubuf = (U*)buf;
U* ibuf = (U*)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2 * threadIdx.y];
U sigma2B = ubuf[2 * threadIdx.y + 1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has the correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1] / U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
}
}
}
template <>
__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals,
const int n1, const int n2, const int i1,
float& mu, float& sigma2, float* buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu = float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1 * n2;
int l = 8 * thrx;
if ((((size_t)lvals) & 3) != 0) {
// 16 bit alignment
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* ubuf = (float*)buf;
float* ibuf = (float*)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2 * threadIdx.y];
float sigma2B = ubuf[2 * threadIdx.y + 1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread that has the correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1] / float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
}
}
}
template <typename U>
U rsqrt(U v) {
return U(1) / sqrt(v);
}
template <>
float rsqrt(float v) {
return rsqrtf(v);
}
template <>
double rsqrt(double v) {
return rsqrt(v);
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of
// this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory<float> {
__device__ float* getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
} // namespace
template <typename T, typename U, typename V>
__global__ void cuApplyLayerNorm(V* __restrict__ output_vals,
U* __restrict__ mean, U* __restrict__ invvar,
const T* __restrict__ vals, const int n1,
const int n2, const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu, sigma2;
cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
const T* lvals = vals + i1 * n2;
V* ovals = output_vals + i1 * n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
}
} else {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
}
}
template <typename T, typename U, typename V>
__device__ void cuLoadWriteStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end, const int n2,
const U* __restrict__ mean, const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] =
curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template <typename T, typename U, typename V>
__device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end, const int n2,
const U* __restrict__ mean, const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template <typename T, typename U, typename V>
__global__ void cuComputePartGradGammaBeta(
const V* __restrict__ dout, const T* __restrict__ input, const int n1,
const int n2, const U* __restrict__ mean, const U* __restrict__ invvar,
U epsilon, U* part_grad_gamma, U* part_grad_beta) {
const int numsegs_n1 =
(n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
const int i1_beg_plus_one =
(blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x + 1;
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
const int thr_load_row_off =
(threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y *
// blockDim.y + (blockDim.y -
// 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar);
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
i1_block += blockDim.y * blockDim.y) {
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k * blockDim.y;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template <typename U, typename V>
__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size, const int n1,
const int n2, V* grad_gamma,
V* grad_beta) {
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr =
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions;
++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template <typename T, typename U, typename V>
__global__ void cuComputeGradInput(const V* __restrict__ dout,
const T* __restrict__ input, const int n1,
const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar, U epsilon,
const V* gamma, T* grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1 * n2;
const V* k_dout = dout + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
}
template <typename T, typename U, typename V>
void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1,
int n2, double epsilon, const V* gamma, const V* beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32, 4, 1);
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma, at::Tensor* beta, double epsilon) {
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
HostApplyLayerNorm(output->data_ptr<scalar_t_out>(),
mean->data_ptr<float>(), invvar->data_ptr<float>(),
input->data_ptr<scalar_t_in>(), n1, n2, epsilon,
gamma != NULL ? gamma->data_ptr<scalar_t_out>() : NULL,
beta != NULL ? beta->data_ptr<scalar_t_out>() : NULL);)
}
template <typename T, typename U, typename V>
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar,
at::Tensor* input, int n1, int n2, const V* gamma,
const V* beta, double epsilon, T* grad_input,
V* grad_gamma, V* grad_beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a =
2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty(
{part_size, n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout, input->data_ptr<T>(), n1, n2, mean, invvar, U(epsilon),
part_grad_gamma.data_ptr<U>(), part_grad_beta.data_ptr<U>());
const dim3 threads3(32, 8, 1);
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.data_ptr<U>(), part_grad_beta.data_ptr<U>(), part_size,
n1, n2, grad_gamma, grad_beta);
}
// compute grad_input
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32, 4, 1);
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout, input->data_ptr<T>(), n1, n2, mean, invvar, U(epsilon), gamma,
grad_input);
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean,
at::Tensor* invvar, at::Tensor* input, int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma, at::Tensor* beta,
double epsilon, at::Tensor* grad_input,
at::Tensor* grad_gamma, at::Tensor* grad_beta) {
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(),
"cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(
dout->data_ptr<scalar_t_out>(), mean->data_ptr<float>(),
invvar->data_ptr<float>(), input, n1, n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->data_ptr<scalar_t_out>() : NULL,
gamma != NULL ? beta->data_ptr<scalar_t_out>() : NULL, epsilon,
grad_input->data_ptr<scalar_t_in>(),
gamma != NULL ? grad_gamma->data_ptr<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->data_ptr<scalar_t_out>() : NULL);)
}

View File

@ -0,0 +1,662 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::ReduceType;
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
BlockStore(ts_store).Store(dst_row + idx, pack);
}
}
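// Note: the moe_* device helpers in this file share one pattern: cub's
// BlockLoad/BlockStore stream a whole row through registers in packs of
// pack_size elements per thread, so `cols` must be divisible by pack_size
// (enforced by the assert at the top of each helper).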
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, pack);
BlockStore(ts_store).Store(src_row + idx, pack);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
BlockStore(ts_store).Store(dst_row1 + idx, pack);
BlockStore(ts_store).Store(dst_row2 + idx, pack);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack1[pack_size], pack2[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
BlockLoad(ts_load).Load(dst_row2 + idx, pack2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack1[i] += pack2[i];
}
BlockStore(ts_store).Store(src_row + idx, pack1);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] *= weight;
}
BlockStore(ts_store).Store(dst_row + idx, pack);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
T *weight_grad, const T weight, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T grad[pack_size], tokens[pack_size];
float thread_sum = 0;
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, grad);
BlockLoad(ts_load).Load(tks_row + idx, tokens);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum += grad[i] * tokens[i];
grad[i] *= weight;
}
BlockStore(ts_store).Store(src_row + idx, grad);
}
block_reduce<float, ReduceType::kSum, 1>(&thread_sum);
if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
const T weight1, const T weight2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack1[pack_size], pack2[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row1 + idx, pack1);
BlockLoad(ts_load).Load(src_row2 + idx, pack2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
}
BlockStore(ts_store).Store(dst_row + idx, pack1);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
T *tks_row1, T *tks_row2, T *weight_grad1,
T *weight_grad2, const T weight1,
const T weight2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size],
sgrad2[pack_size];
float thread_sum[2] = {0, 0};
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, grad);
BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum[0] += grad[i] * tokens1[i];
thread_sum[1] += grad[i] * tokens2[i];
sgrad1[i] = weight1 * grad[i];
sgrad2[i] = weight2 * grad[i];
}
BlockStore(ts_store).Store(src_row1 + idx, sgrad1);
BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
}
block_reduce<float, ReduceType::kSum, 2>(thread_sum);
if (threadIdx.x == 0)
*weight_grad1 = static_cast<T>(thread_sum[0]);
else if (threadIdx.x == 1)
*weight_grad2 = static_cast<T>(thread_sum[1]);
}
// DISPATCH KERNELS --------------------------------
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols);
else if (indicator1 != 0)
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
else if (indicator2 != 0)
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
else
return;
}
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols);
else if (indicator1 != 0)
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row1, cols);
else if (indicator2 != 0)
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row2, cols);
else
return;
}
template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
int *mask1, int *mask2, int *dest1,
int *dest2, const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_fwd_selector<T, block_size, pack_size>(
batch_tokens + (row * h), expert_input + (dest1[row] * h),
expert_input + (dest2[row] * h), h, mask1[row], indicator2);
}
template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
int *mask2, int *dest1, int *dest2,
const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_bwd_selector<T, block_size, pack_size>(
tokens_grad + (row * h), expert_grad + (dest1[row] * h),
expert_grad + (dest2[row] * h), h, mask1[row], indicator2);
}
// COMBINE KERNELS --------------------------------
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
const int cols, const T weight1,
const T weight2, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
weight1, weight2, cols);
else if (indicator1 != 0)
moe_cb_one_fwd<T, block_size, pack_size>(src_row1, dst_row, weight1, cols);
else if (indicator2 != 0)
moe_cb_one_fwd<T, block_size, pack_size>(src_row2, dst_row, weight2, cols);
else
return;
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
const int cols, T *tks_row1, T *tks_row2,
T *wt_grad1, T *wt_grad2, const T weight1,
const T weight2, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
tks_row1, tks_row2, wt_grad1,
wt_grad2, weight1, weight2, cols);
else if (indicator1 != 0)
moe_cb_one_bwd<T, block_size, pack_size>(src_row1, dst_row, tks_row1,
wt_grad1, weight1, cols);
else if (indicator2 != 0)
moe_cb_one_bwd<T, block_size, pack_size>(src_row2, dst_row, tks_row2,
wt_grad2, weight2, cols);
else
return;
}
template <typename T, int block_size, int pack_size>
__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
T *logits, int *mask1, int *mask2, int *dest1,
int *dest2, const int e, const int c,
const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e);
moe_cb_fwd_selector<T, block_size, pack_size>(
expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row],
indicator2);
}
template <typename T, int block_size, int pack_size>
__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad, int *mask1,
int *mask2, int *dest1, int *dest2,
const int e, const int c, const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
moe_cb_bwd_selector<T, block_size, pack_size>(
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
tokens_grad + (row * h), h, tks + (dest1[row] * h),
tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1],
row_log[eid2], mask1[row], indicator2);
}
// CUMSUM KERNEL --------------------------------
template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
const int e) {
assert(s % pack_size == 0);
constexpr int bpack_size = block_size * pack_size;
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
__shared__ int temp[block_size + 1];
int pack[pack_size];
for (int idx = 0; idx < s; idx += bpack_size) {
int offset = 1;
if (idx + tps < s) {
temp[tid] = inputs[tps * e + bid];
#pragma unroll
for (int i = 1; i < pack_size; ++i) {
pack[i] = inputs[(tps + i) * e + bid];
}
#pragma unroll
for (int i = 1; i < pack_size; ++i) {
temp[tid] += pack[i];
}
}
for (int i = block_size >> 1; i > 0; i >>= 1) {
__syncthreads();
if (tid < i) {
int j = offset * (2 * tid + 1) - 1;
temp[j + offset] += temp[j];
}
offset <<= 1;
}
if (tid == 0) {
temp[block_size] = temp[block_size - 1];
temp[block_size - 1] = 0;
}
for (int i = 1; i < block_size; i <<= 1) {
offset >>= 1;
__syncthreads();
if (tid < i) {
int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j];
temp[j] = temp[k];
temp[k] += ts;
}
}
__syncthreads();
if (tid == 0) temp[0] = temp[block_size];
__syncthreads();
if (idx + tps < s) {
temp[tid + 1] += last_sum;
#pragma unroll
for (int i = pack_size - 1; i > 0; --i) {
outputs[(tps + i) * e + bid] = temp[tid + 1];
temp[tid + 1] -= pack[i];
}
outputs[tps * e + bid] = temp[tid + 1];
}
__syncthreads();
last_sum += temp[0];
inputs += bpack_size * e;
outputs += bpack_size * e;
}
}
// LAUNCH FUNCTIONS --------------------------------
template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
int *mask2, int *dest1, int *dest2, const int s,
const int h) {
if (h < 256)
moe_dpch_fwd_kernel<T, 32, 4>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 512)
moe_dpch_fwd_kernel<T, 32, 8>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 1024)
moe_dpch_fwd_kernel<T, 32, 16>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 2048)
moe_dpch_fwd_kernel<T, 64, 16>
<<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else
moe_dpch_fwd_kernel<T, 128, 16>
<<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
}
template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
int *dest1, int *dest2, const int s, const int h) {
if (h < 256)
moe_dpch_bwd_kernel<T, 32, 4>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 512)
moe_dpch_bwd_kernel<T, 32, 8>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 1024)
moe_dpch_bwd_kernel<T, 32, 16>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 2048)
moe_dpch_bwd_kernel<T, 64, 16>
<<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else
moe_dpch_bwd_kernel<T, 128, 16>
<<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
}
template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2, int *dest1, int *dest2,
const int s, const int e, const int c, const int h) {
if (h < 256)
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
e, c, h);
else if (h < 512)
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
e, c, h);
else if (h < 1024)
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
e, c, h);
else if (h < 2048)
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
e, c, h);
else
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1,
dest2, e, c, h);
}
template <typename T>
void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
T *logits_grad, int *mask1, int *mask2, int *dest1,
int *dest2, const int s, const int e, const int c,
const int h) {
if (h < 256)
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
logits, logits_grad, mask1, mask2,
dest1, dest2, e, c, h);
else // if (h < 512)
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
logits, logits_grad, mask1, mask2,
dest1, dest2, e, c, h);
// else if (h < 1024)
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
// dest1, dest2, e, c, h);
// else
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
// dest1, dest2, e, c, h);
}
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
if (s <= 256)
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
else if (s <= 512)
cumsum_kernel<512, 1><<<e, 512>>>(inputs, outputs, s, e);
else if (s <= 1024)
cumsum_kernel<1024, 1><<<e, 1024>>>(inputs, outputs, s, e);
else if (s <= 2048)
cumsum_kernel<1024, 2><<<e, 1024>>>(inputs, outputs, s, e);
else
cumsum_kernel<1024, 4><<<e, 1024>>>(inputs, outputs, s, e);
}
// API FUNCTIONS --------------------------------
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented yet for specific data type."); \
}
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros(
{ec, h},
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
batch_tokens.scalar_type(), "moe dispatch forward",
moe_dpch_fwd_launch<scalar_t>(
batch_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
mask[0].data_ptr<int>(), k == 1 ? nullptr : mask[1].data_ptr<int>(),
dest_idx[0].data_ptr<int>(),
k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(), s, h));
return res;
}
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros(
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
expert_grad.scalar_type(), "moe dispatch backward",
moe_dpch_bwd_launch<scalar_t>(
res.data_ptr<scalar_t>(), expert_grad.data_ptr<scalar_t>(),
mask[0].data_ptr<int>(), k == 1 ? nullptr : mask[1].data_ptr<int>(),
dest_idx[0].data_ptr<int>(),
k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(), s, h));
return res;
}
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(expert_tokens.dtype() == logits.dtype());
auto res = torch::zeros(
{s, h},
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
expert_tokens.scalar_type(), "moe combine forward",
moe_cb_fwd_launch<scalar_t>(
expert_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
logits.data_ptr<scalar_t>(), mask[0].data_ptr<int>(),
k == 1 ? nullptr : mask[1].data_ptr<int>(), dest_idx[0].data_ptr<int>(),
k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(), s, e, c,
h));
return res;
}
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(tokens_grad.dtype() == expert_tokens.dtype());
assert(expert_tokens.dtype() == logits.dtype());
auto egrad = torch::zeros(
{e * c, h},
torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
wgrad = torch::zeros(
{s, e}, torch::dtype(logits.dtype()).device(logits.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
tokens_grad.scalar_type(), "moe combine backward",
moe_cb_bwd_launch<scalar_t>(
tokens_grad.data_ptr<scalar_t>(), egrad.data_ptr<scalar_t>(),
expert_tokens.data_ptr<scalar_t>(), logits.data_ptr<scalar_t>(),
wgrad.data_ptr<scalar_t>(), mask[0].data_ptr<int>(),
k == 1 ? nullptr : mask[1].data_ptr<int>(), dest_idx[0].data_ptr<int>(),
k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(), s, e, c,
h));
return {egrad, wgrad};
}
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
assert(mask.dim() == 2);
assert(mask.dtype() == torch::kInt32);
const int s = mask.size(0), e = mask.size(1);
auto res =
torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
cumsum_launch(mask.data_ptr<int>(), res.data_ptr<int>(), s, e);
return res;
}
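
The dispatch/combine/cumsum math implemented by the kernels above can be written in a few lines of plain PyTorch. The sketch below is a readability aid for the top-1 (`k == 1`) routing case, not the optimized path; tensor and function names are illustrative.

```python
import torch

def moe_dispatch_ref(batch_tokens, mask, dest_idx, ec):
    # batch_tokens: [s, h]; mask, dest_idx: [1, s]; ec = e * c expert slots.
    s, h = batch_tokens.shape
    expert_input = torch.zeros(ec, h, dtype=batch_tokens.dtype)
    for row in range(s):
        if mask[0, row] != 0:
            # copy the token row into its assigned expert slot
            expert_input[dest_idx[0, row]] = batch_tokens[row]
    return expert_input

def moe_combine_ref(expert_tokens, logits, mask, dest_idx, c):
    # expert_tokens: [e * c, h]; logits: [s, e]; the expert id is slot // c.
    s = mask.shape[1]
    out = torch.zeros(s, expert_tokens.shape[1], dtype=expert_tokens.dtype)
    for row in range(s):
        if mask[0, row] != 0:
            slot = dest_idx[0, row]
            out[row] = logits[row, slot // c] * expert_tokens[slot]
    return out

def cumsum_sub_one_ref(mask):
    # matches cumsum_sub_one_in_dim0: inclusive cumsum over dim 0, minus one
    return mask.cumsum(dim=0) - 1
```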


@@ -0,0 +1,146 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "../common/micros.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
} adamMode_t;
using MATH_T = float;
template <typename T_g, typename T_p>
struct AdamFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta1_correction,
const float beta2_correction, const float epsilon, const float lr,
adamMode_t mode, const float decay, const float div_scale) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_g *g = (T_g *)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T_p *p = (T_p *)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T_p *m = (T_p *)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T_p *v = (T_p *)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (div_scale > 0) r_g[ii] /= div_scale;
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction, const float weight_decay,
const float div_scale) {
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
DISPATCH_FLOAT_AND_HALF_FOR_G_P(
tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay, div_scale);)
AT_CUDA_CHECK(cudaGetLastError());
}
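
For reference, the per-element arithmetic inside `AdamFunctor` is the standard Adam / AdamW update. The hedged PyTorch sketch below reproduces one step for a single tensor under the same two modes, assuming bias correction is enabled; names are illustrative.

```python
import torch

def fused_adam_ref(p, g, m, v, lr, beta1, beta2, eps, step,
                   mode=1, weight_decay=0.0, div_scale=-1.0):
    """mode=0 mirrors ADAM_MODE_0 (L2), mode=1 mirrors ADAM_MODE_1 (AdamW)."""
    if div_scale > 0:
        g = g / div_scale
    bias_c1 = 1 - beta1 ** step
    bias_c2 = 1 - beta2 ** step
    if mode == 0:                          # L2: fold weight decay into the gradient
        g = g + weight_decay * p
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    denom = (v / bias_c2).sqrt() + eps
    update = (m / bias_c1) / denom
    if mode != 0:                          # AdamW: decoupled weight decay
        update = update + weight_decay * p
    p.add_(update, alpha=-lr)
    return p, m, v
```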


@@ -0,0 +1,130 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
#include "../common/micros.h"
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template <int n>
struct TensorListMetadata {
void *addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a
// full int.
int start_tensor_this_launch;
};
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
volatile int *noop_flag, T tl,
U callable, ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however
// it likes.
callable(chunk_size, noop_flag, tl, args...);
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
int block_size, int chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size();
l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0,
"Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory =
(contiguous_memory ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
"Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor =
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if (chunk == chunks_this_tensor - 1) {
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3
// << std::endl;
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
} else {
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3
// << std::endl;
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
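
The bookkeeping in `multi_tensor_apply` — split every tensor into `chunk_size` pieces and flush a kernel launch whenever the tensor table, the block table, or the input runs out — is easier to follow stripped of CUDA details. The Python sketch below models only that control flow; the capacity limits mirror `depth_to_max_tensors` / `depth_to_max_blocks` but are passed as illustrative parameters.

```python
def plan_launches(numels, chunk_size, max_tensors=36, max_blocks=320):
    """Group (tensor, chunk) pairs into launches the way multi_tensor_apply does."""
    launches, blocks, tensors_in_launch = [], [], 0
    for t, n in enumerate(numels):
        tensors_in_launch += 1
        chunks = (n + chunk_size - 1) // chunk_size
        for c in range(chunks):
            blocks.append((t, c))          # one CUDA block processes one chunk
            tensors_full = tensors_in_launch == max_tensors and c == chunks - 1
            blocks_full = len(blocks) == max_blocks
            last_chunk = t == len(numels) - 1 and c == chunks - 1
            if tensors_full or blocks_full or last_chunk:
                launches.append(blocks)
                blocks = []
                # a tensor split across launches must be re-registered in the next one
                tensors_in_launch = 0 if c == chunks - 1 else 1
    return launches

# e.g. three tensors with 2048-element chunks -> a single launch with 39 blocks
print(plan_launches([1000, 5000, 70000], chunk_size=2048))
```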


@@ -0,0 +1,387 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "../common/micros.h"
#include "include/block_reduce.h"
#define BLOCK_SIZE 512
#define ILP 4
using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::reduce_block_into_lanes;
using colossalAI::cuda::utils::reduce_block_into_lanes_max_op;
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
int src_offset) {
typedef
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
template <typename x_t>
struct L2NormFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
float *output, float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be
// sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_x, x, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
float next = static_cast<float>(r_x[ii]);
vals[ii] += next * next;
}
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
float next = static_cast<float>(x[i]);
vals[ii] += next * next;
}
}
}
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val += vals[i];
float final = reduce_block_into_lanes(s_vals, val);
if (threadIdx.x == 0) {
if (!isfinite(final))
*noop_gmem =
1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] += final;
if (per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
max_chunks_per_tensor +
chunk_idx] = final;
}
}
};
// Probably better to template the norm type, but since we are not likely to
// support other norms, a separate functor is kept for simplicity.
template <typename x_t>
struct MaxNormFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
float *output, float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be
// sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_x, x, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
float next = static_cast<float>(r_x[ii]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
float next = static_cast<float>(x[i]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
}
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
float final = reduce_block_into_lanes_max_op(s_vals, val);
if (threadIdx.x == 0) {
if (!isfinite(final))
*noop_gmem =
1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
if (per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
max_chunks_per_tensor +
chunk_idx] = final;
}
}
};
__global__ void cleanup(float *output, float *output_per_tensor, float *ret,
float *ret_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
__shared__ float vals[512];
if (blockIdx.x == 0) {
float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(final);
}
if (per_tensor) {
float *output_this_tensor =
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val += output_this_tensor[i];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
}
}
__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
float *ret_per_tensor, bool per_tensor,
int max_chunks_per_tensor, int norm_type,
float alpha, float beta) {
__shared__ float vals[512];
if (blockIdx.x == 0) {
float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x];
if (norm_type == 0) {
float final = reduce_block_into_lanes_max_op(vals, val);
if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
} else {
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
}
}
if (per_tensor) {
float *output_this_tensor =
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
if (norm_type == 0) {
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));
float final = reduce_block_into_lanes_max_op(vals, val);
if (threadIdx.x == 0)
ret_per_tensor[blockIdx.x] =
alpha * ret_per_tensor[blockIdx.x] + beta * final;
} else {
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val += output_this_tensor[i];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0)
ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] *
ret_per_tensor[blockIdx.x] +
beta * final);
}
}
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
bool per_tensor =
per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor =
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor =
at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
per_tensor, max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it will be negligible end to
// end. We could get rid of it by adding persistence logic to the functor and the
// multi-tensor harness, but we keep it simple for now.
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
ret.data_ptr<float>(),
per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
// Compute and update grad norm
// Here we use a per-tensor norm and blend the new norm (n) and the old norm (gn) by
// L-2: gn = sqrt(a * gn^2 + b * n^2)
// L-inf: gn = a * gn + b * n
void multi_tensor_norm_out_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor out,
const float alpha, const float beta, const int norm_type) {
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(),
"noop flag should be on the same device as tensors");
// we don't need global thus uses empty here
auto output = at::empty({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor =
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
// Although each slot is written once and then read, the buffer still needs to be
// zeroed, since trailing elements also participate in the cleanup kernel.
output_per_tensor =
at::zeros({ntensors * max_chunks_per_tensor}, float_options);
if (norm_type == 0) {
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
MaxNormFunctor<scalar_t_0>(), output.data_ptr<float>(),
output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);)
} else {
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);)
}
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it will be negligible end to
// end. We could get rid of it by adding persistence logic to the functor and the
// multi-tensor harness, but we keep it simple for now.
auto ret = at::empty({1}, output.options());
// Add the following device guard since sometimes the tensors are on one device
// while the CUDA stream is on another device, which results in an illegal memory
// access error.
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup_v2<<<ntensors, 512, 0, stream>>>(
output.data_ptr<float>(), output_per_tensor.data_ptr<float>(),
ret.data_ptr<float>(), out.data_ptr<float>(), true, max_chunks_per_tensor,
norm_type, alpha, beta);
return;
}
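
The blending performed by `cleanup_v2` on behalf of `multi_tensor_norm_out_cuda` follows the two formulas in the comment above. As a scalar reference (with `n` the freshly computed per-tensor norm and `gn` the running norm; names illustrative):

```python
import math

def blend_norm(gn, n, alpha, beta, norm_type):
    # norm_type == 0 -> L-inf style blend, otherwise L2 blend (as in cleanup_v2)
    if norm_type == 0:
        return alpha * gn + beta * n
    return math.sqrt(alpha * gn * gn + beta * n * n)
```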


@@ -0,0 +1,354 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "../common/micros.h"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
int src_offset) {
typedef
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
typedef enum {
MOMENT_MODE_0 = 0, // L2 regularization mode
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
} adamMode_t;
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
using MATH_T = float;
template <typename T>
struct LAMBStage1Functor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta3,
const float beta1_correction, const float beta2_correction,
const float epsilon, adamMode_t mode, const float decay,
const float *global_grad_norm, const float max_global_grad_norm) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm =
(*global_grad_norm) > max_global_grad_norm
? (*global_grad_norm) / max_global_grad_norm
: 1.0f;
T *g = (T *)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T *p = (T *)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T *m = (T *)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T *v = (T *)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) &&
is_aligned(p) && is_aligned(m) && is_aligned(v)) {
T l_g[ILP];
T l_p[ILP];
T l_m[ILP];
T l_v[ILP];
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(l_g, g, 0, i_start);
if (decay != 0) load_store(l_p, p, 0, i_start);
load_store(l_m, m, 0, i_start);
load_store(l_v, v, 0, i_start);
// unpack
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_g[ii] = l_g[ii];
if (decay == 0) {
r_p[ii] = MATH_T(0);
} else {
r_p[ii] = l_p[ii];
}
r_m[ii] = l_m[ii];
r_v[ii] = l_v[ii];
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay * r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
} else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
l_p[ii] = r_p[ii];
l_m[ii] = r_m[ii];
l_v[ii] = r_v[ii];
}
// store
load_store(g, l_p, i_start, 0);
load_store(m, l_m, i_start, 0);
load_store(v, l_v, i_start, 0);
}
} else {
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
// special ?optimization? for lamb stage 1
if (decay == 0) {
r_p[ii] = MATH_T(0);
} else {
r_p[ii] = p[i];
}
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay * r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
} else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
g[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
}
};
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template <typename T>
struct LAMBStage2Functor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
const float *per_tensor_param_norm, const float *per_tensor_update_norm,
const float learning_rate, const float decay, bool use_nvlamb) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0)) {
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f)
? learning_rate * (param_norm / update_norm)
: learning_rate;
}
T *update = (T *)tl.addresses[0][tensor_loc];
update += chunk_idx * chunk_size;
T *p = (T *)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) &&
is_aligned(update)) {
T r_p[ILP];
T r_update[ILP];
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_p, p, 0, i_start);
load_store(r_update, update, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_p[ii] = static_cast<MATH_T>(r_p[ii]) -
(ratio * static_cast<MATH_T>(r_update[ii]));
}
load_store(p, r_p, i_start, 0);
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_p[ILP];
MATH_T r_update[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
}
}
}
}
}
};
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int bias_correction,
const float weight_decay, const int grad_averaging,
const int mode, at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python) {
using namespace at;
// Master weights and 32-bit momentum (potentially of different types) are not
// handled here, so we assume all tensors share the same type.
bool use_nvlamb =
use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
// Handle grad averaging mode
float beta3 = 1.0f;
if (grad_averaging == 1) beta3 = 1 - beta1;
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
tensor_lists.begin() + 1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1,
tensor_lists.begin() + 2);
// Compute per tensor param norm
auto param_norm_tuple =
multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
// We now modify grad in place to store the update before computing its norm.
// Generally this is not an issue, since grads are routinely modified in step().
// We could also use a list of empty tensors to avoid this, but we prefer to save
// the extra memory and CPU code.
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
beta3, // 1-beta1 or 1 depends on averaging mode
bias_correction1, bias_correction2, epsilon,
(adamMode_t)mode, weight_decay,
global_grad_norm.data_ptr<float>(), max_grad_norm);)
// Compute update norms
auto update_norm_tuple =
multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
std::vector<std::vector<at::Tensor>> grad_param_list(
tensor_lists.begin(), tensor_lists.begin() + 2);
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list,
LAMBStage2Functor<scalar_t_0>(),
std::get<1>(param_norm_tuple).data_ptr<float>(),
std::get<1>(update_norm_tuple).data_ptr<float>(),
lr, weight_decay, use_nvlamb);)
AT_CUDA_CHECK(cudaGetLastError());
}
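
In short, stage 1 above turns the gradient buffer into an Adam-style update, and stage 2 rescales that update per tensor by the trust ratio ||p|| / ||update|| before applying it. A hedged single-tensor reference of stage 2 (illustrative names):

```python
import torch

def lamb_stage2_ref(p, update, lr, weight_decay, param_norm, update_norm,
                    use_nvlamb=False):
    # The adaptive ratio applies to all tensors under NVLAMB, otherwise only
    # to tensors with non-zero weight decay (as in LAMBStage2Functor).
    ratio = lr
    if (use_nvlamb or weight_decay != 0.0) and param_norm != 0.0 and update_norm != 0.0:
        ratio = lr * (param_norm / update_norm)
    p.add_(update, alpha=-ratio)
    return p
```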


@@ -0,0 +1,125 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>
#include "multi_tensor_apply.cuh"
#include "../common/micros.h"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
int src_offset) {
typedef
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
template <typename in_t, typename out_t>
struct ScaleFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int *noop_gmem,
TensorListMetadata<2> &tl,
float scale) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
in_t *in = (in_t *)tl.addresses[0][tensor_loc];
in += chunk_idx * chunk_size;
out_t *out = (out_t *)tl.addresses[1][tensor_loc];
out += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
bool finite = true;
in_t r_in[ILP];
out_t r_out[ILP];
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) &&
is_aligned(out)) {
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_in, in, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
}
// store
load_store(out, r_out, i_start, 0);
}
} else {
// Non-divergent exit condition for __syncthreads, not necessary here
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point
// unrolling the write loop, since writes just fire off once their LDGs
// arrive. Put another way, the STGs are dependent on the LDGs, but not
// on each other. There is still compute ILP benefit from unrolling the
// loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) out[i] = r_out[ii];
}
}
}
if (!finite)
*noop_gmem =
1; // Blindly fire off a write. These will race but that's ok.
}
};
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale) {
using namespace at;
// The output (downscaled) type is always float.
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_AND_HALF(
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(),
scale);))
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
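
Per tensor, `ScaleFunctor` is just a scaled, dtype-converting copy plus a finiteness check that raises the shared noop flag. A hedged PyTorch rendering (names illustrative):

```python
import torch

def scale_ref(src, dst, scale, noop_flag):
    # noop_flag is a one-element int tensor shared by the whole tensor list
    dst.copy_(src.float() * scale)
    if not torch.isfinite(src).all():
        noop_flag.fill_(1)   # the kernel fires a racy blind write; a plain set here
    return dst
```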


@@ -0,0 +1,167 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_runtime.h>
#include "../common/micros.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
/**
* Perform fused SGD on multiple buffers
* N: number of tensors
* tl[0] : gradients
* tl[1] : weights
* tl[2] : momentum buffers
* tl[3] : fp16 weights (if appropriate)
* wd : weight_decay (scalar)
* momentum : momentum (scalar)
* dampening : momentum dampening (scalar)
* lr : learning rate (scalar)
* nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/
template <typename T_grad, typename T_weight>
struct SGDFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
float wd, float momentum, float dampening, float lr, bool nesterov,
bool first_run, bool wd_after_momentum, float scale) {
// Early exit if we don't need to do anything
if (*noop_gmem) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
grad_in += chunk_idx * chunk_size;
T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
weight_in += chunk_idx * chunk_size;
T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
mom_in += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// Non-divergent exit condition for the __syncthreads
float incoming_grads[ILP];
float incoming_weights[ILP];
float incoming_moms[ILP];
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_grads[ii] = 0;
incoming_weights[ii] = 0;
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
// apply weight decay before momentum if necessary
if (wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
if (momentum != 0.f) {
if (!first_run)
incoming_moms[ii] = incoming_moms[ii] * momentum +
(1.f - dampening) * incoming_grads[ii];
else // initialize momentums to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
if (nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii];
else
incoming_grads[ii] = incoming_moms[ii];
}
// Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
// adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]);
// also write out the new momentum
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
}
}
}
}
};
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd, float momentum, float dampening, float lr,
bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type
// 1. fp16, fp16, fp16
// 2. fp32, fp32, fp32
// 3. fp16, fp32, fp32
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<at::Half, at::Half>(), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum,
scale);
}
// Case 2. fp32, fp32, fp32
else if (grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<float, float>(), wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 3. fp16, fp32, fp32
else if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<at::Half, float>(), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum,
scale);
} else {
AT_ERROR(
"multi_tensor_sgd only supports some combinations of gradient & weight "
"types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type,
", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
}
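
The fused SGD update above follows the usual ordering of weight decay, momentum, dampening and Nesterov. A hedged single-tensor reference of `SGDFunctor`'s math (illustrative names, float buffers assumed):

```python
import torch

def fused_sgd_ref(weight, grad, mom, wd, momentum, dampening, lr,
                  nesterov=False, first_run=False, wd_after_momentum=False,
                  scale=1.0):
    g = grad.float() * scale
    if wd != 0.0 and not wd_after_momentum:
        g = g + wd * weight
    if momentum != 0.0:
        if first_run:
            mom.copy_(g)                       # initialize momentum with the gradient
        else:
            mom.mul_(momentum).add_(g, alpha=1.0 - dampening)
        g = g + momentum * mom if nesterov else mom.clone()
    if wd != 0.0 and wd_after_momentum:
        g = g + wd * weight
    weight.add_(g, alpha=-lr)
    return weight, mom
```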


@@ -0,0 +1,84 @@
#include <torch/extension.h>
void decode_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
at::Tensor& block_tables, // [batch_size, max_seq_len]
int max_seq_len_in_batch);
void rotary_embedding(
torch::Tensor& query, // [total_tokens, head_num, head_dim]
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
torch::Tensor& cos, // [total_tokens, head_dim]
torch::Tensor& sin, // [total_tokens, head_dim]
bool high_precision);
void rotary_embedding_and_cache_copy(
torch::Tensor& query, // [num_tokens, head_num, head_dim]
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
torch::Tensor& cos, // [num_tokens, head_dim]
torch::Tensor& sin, // [num_tokens, head_dim]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables, // [batch_size, max_seq_len]
bool high_precision);
torch::Tensor silu_and_mul(const torch::Tensor& ins);
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon);
void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon);
void get_cos_and_sin(at::Tensor& cos_cache, // [max_rotary_position, head_dim]
at::Tensor& sin_cache, // [max_rotary_position, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
int max_seq_len_in_batch, bool is_prompts);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the context stage.");
m.def(
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
"Performing Rotary Embedding-related calculations and KVCache Memcopy.");
m.def("rotary_embedding", &rotary_embedding,
"Performing Rotary Embedding-related calculations.");
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
m.def("rms_layernorm", &rms_layernorm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
"In-place fused Add and RMS Normalization.");
m.def("get_cos_and_sin", &get_cos_and_sin, "Get cos and sin from the cache.");
}
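
Once this extension is compiled, the bound functions are called directly from Python with the tensor layouts annotated above. A hedged usage sketch for the RMS-norm bindings; the module name `inference_ops_cuda` is an assumption for illustration, not the actual build target.

```python
import torch
import inference_ops_cuda as inference_ops  # assumed module name for the built extension

hidden = 4096
x = torch.randn(8, hidden, device="cuda", dtype=torch.float16)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)

out = torch.empty_like(x)
inference_ops.rms_layernorm(out, x, weight, 1e-6)                 # out = RMSNorm(x) * weight

residual = torch.randn_like(x)
inference_ops.fused_add_rms_layernorm(x, residual, weight, 1e-6)  # in-place add + RMSNorm
```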


@@ -0,0 +1,141 @@
/*This code from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "../../common/micros.h"
namespace {
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
int &n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert(input.sizes()[i + idiff] == normalized_shape[i]);
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
at::Tensor beta) {
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
int &n2) {
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim)
.equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input, normalized_shape, n1, n2);
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
check_args(input, normalized_shape, n1, n2);
check_args(normalized_shape, gamma, beta);
}
} // namespace
void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
at::Tensor *input, int n1, int n2,
at::IntArrayRef normalized_shape, at::Tensor *gamma,
at::Tensor *beta, double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor output =
at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean =
at::empty({n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
&gamma, &beta, epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
at::Tensor *invvar, at::Tensor *input, int n1,
int n2, at::IntArrayRef normalized_shape,
at::Tensor *gamma, at::Tensor *beta,
double epsilon, at::Tensor *grad_input,
at::Tensor *grad_gamma, at::Tensor *grad_beta);
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon,
&grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}
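
As a sketch of the Python-side contract (assuming the compiled module is available under a placeholder name `layernorm_cuda`): `forward_affine` returns the normalized output plus the per-row mean and inverse variance that `backward_affine` consumes, and the output should match `torch.nn.functional.layer_norm` up to numerical tolerance.

```python
import torch
import torch.nn.functional as F

def check_forward_affine(layernorm_cuda):
    """Compare the fused forward with PyTorch's layer_norm (sketch; `layernorm_cuda`
    is a placeholder for the compiled module defined above)."""
    hidden = 1024
    x = torch.randn(4, 16, hidden, device="cuda")
    gamma = torch.ones(hidden, device="cuda")
    beta = torch.zeros(hidden, device="cuda")
    out, mean, invvar = layernorm_cuda.forward_affine(x, (hidden,), gamma, beta, 1e-5)
    ref = F.layer_norm(x, (hidden,), gamma, beta, 1e-5)
    # mean / invvar have shape [n1] = [4 * 16]: one entry per normalized row.
    assert torch.allclose(out, ref, atol=1e-4)
    return out, mean, invvar
```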


@@ -0,0 +1,97 @@
#include <torch/extension.h>
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx);
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask, torch::Tensor dest_idx) {
CHECK_INPUT(batch_tokens);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx);
}
torch::Tensor moe_dispatch_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_grad);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx);
}
torch::Tensor moe_combine_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_tokens);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask,
dest_idx);
}
std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(tokens_grad);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens,
logits, mask, dest_idx);
}
torch::Tensor moe_cumsum(torch::Tensor mask) {
CHECK_INPUT(mask);
return cumsum_sub_one_in_dim0(mask);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0");
m.def("dispatch_forward", &moe_dispatch_forward,
"Forward operation in MoE dispatch function");
m.def("dispatch_backward", &moe_dispatch_backward,
"Backward operation in MoE dispatch function");
  m.def("combine_forward", &moe_combine_forward,
        "Forward operation in MoE combine function");
  m.def("combine_backward", &moe_combine_backward,
        "Backward operation in MoE combine function");
}
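
Of these ops, `cumsum_sub_one` has the most self-describing semantics; a plain PyTorch reference is sketched below for checking the fused kernel. The interpretation of `mask` as a [num_tokens, num_experts] routing matrix is an assumption drawn from the parameter names, not from this file.

```python
import torch

def moe_cumsum_reference(mask: torch.Tensor) -> torch.Tensor:
    """Reference for cumsum_sub_one_in_dim0: cumulative sum along dim 0, minus one.

    With `mask` as a 0/1 routing matrix of shape [num_tokens, num_experts] (assumed),
    the result gives each routed token its 0-based slot inside its expert's bucket.
    """
    return torch.cumsum(mask, dim=0) - 1


mask = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.int32)
print(moe_cumsum_reference(mask))
# -> [[0, -1],
#     [0,  0],
#     [1,  0]]
```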


@@ -0,0 +1,49 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <torch/extension.h>
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale);
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd, float momentum, float dampening, float lr,
bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction, const float weight_decay,
const float div_scale);
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int bias_correction,
const float weight_decay, const int grad_averaging,
const int mode, at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply the gradient update to parameters for the Adam optimizer");
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Compute and apply the update for the LAMB optimizer");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
}
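
A hedged sketch of driving the fused Adam entry point from Python follows. The `fused_optim` handle stands for the compiled module, and the `[grads, params, exp_avgs, exp_avg_sqs]` ordering of `tensor_lists` follows the apex convention this file is adapted from; both are assumptions to verify against the kernel before use.

```python
import torch

def fused_adam_step(fused_optim, params, grads, exp_avgs, exp_avg_sqs, step):
    """One Adam step over flat lists of CUDA tensors via multi_tensor_adam (sketch)."""
    # Flag the kernels set when they encounter inf/nan.
    noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")
    fused_optim.multi_tensor_adam(
        2048 * 32,                               # chunk_size
        noop_flag,
        [grads, params, exp_avgs, exp_avg_sqs],  # assumed list ordering
        1e-3,                                    # lr
        0.9, 0.999, 1e-8,                        # beta1, beta2, epsilon
        step,                                    # current step (1-based)
        0,                                       # mode (assumed: 0 = L2-style weight decay)
        1,                                       # bias_correction
        0.0,                                     # weight_decay
        -1.0,                                    # div_scale (assumed: < 0 disables grad unscaling)
    )
```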


@@ -0,0 +1,70 @@
/* This code is from NVIDIA Megatron
 * with minor changes. */
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
int attn_heads);
torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::
get_batch_per_block,
"Return Batch per block size.");
}
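
For intuition, the forward binding computes a scaled, explicitly masked softmax over the last dimension of a 4D score tensor; a plain PyTorch reference is sketched below (the -10000 fill value and the mask-equals-1 convention are taken from the CUDA kernel header elsewhere in this commit). This is only a numerical reference, not the fused implementation.

```python
import torch

def scaled_masked_softmax_reference(x: torch.Tensor, mask: torch.Tensor,
                                    scale: float) -> torch.Tensor:
    # x:    [batches, attn_heads, query_seq_len, key_seq_len], fp16 or bf16
    # mask: [pad_batches, 1, query_seq_len, key_seq_len], uint8, 1 = masked out
    scores = x.float() * scale
    scores = scores.masked_fill(mask.bool(), -10000.0)
    return torch.softmax(scores, dim=-1).to(x.dtype)
```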


@@ -0,0 +1,54 @@
/* This code is from NVIDIA Megatron
 * with minor changes. */
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results, float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
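
The upper-triangular variant applies an implicit causal mask instead of an explicit one; a rough PyTorch reference of the forward semantics (3D input of shape [attn_batches, seq_len, seq_len], positions above the diagonal excluded) is given below for checking purposes only.

```python
import torch

def scaled_upper_triang_masked_softmax_reference(x: torch.Tensor,
                                                 scale: float) -> torch.Tensor:
    # x: [attn_batches, seq_len, seq_len]; entries above the diagonal are masked out,
    # so row i only attends to key positions 0..i.
    seq_len = x.size(-1)
    causal = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
    scores = (x.float() * scale).masked_fill(causal, float("-inf"))
    return torch.softmax(scores, dim=-1).to(x.dtype)
```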


@@ -0,0 +1,426 @@
/* This code is from vLLM:
 * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
 * with minor changes. */
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include "block_reduce.h"
#include "../common/micros.h"
#include "funcs/cast_functor.h"
#include "funcs/op_functor.h"
using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::ReduceType;
using colossalAI::cuda::funcs::TypeConverter;
using colossalAI::cuda::funcs::CastFunctor;
using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
if (DATA_SIZE == 2) { \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
} else { \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
general_##__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
  }
// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
__global__ void rms_layernorm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
using scalar2_t = typename TypeConverter<scalar_t>::Type;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
__shared__ float s_variance;
  /*
   * Since the hidden dimensions of open-sourced LLMs mainly range from
   * 4096 (LLaMA-7B) to 8192 (LLaMA-65B), we cap the supported hidden
   * dimension at 8192 and each thread's capacity for caching input
   * elements at 8 (8192 = 8 * 1024). This breaks for extremely large
   * models such as Megatron-Turing NLG 530B, whose hidden dimension
   * reaches 20480.
   */
scalar2_t x_local[4];
scalar2_t* out_ptr = (scalar2_t*)out;
const scalar2_t* input_ptr = (scalar2_t*)input;
const scalar2_t* weight_ptr = (const scalar2_t*)weight;
float variance = 0.0f;
int row_offset = blockIdx.x * hidden_size / 2;
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input_ptr[id];
float v1 = CastFunctor<scalar_t,float>()(x_local[cnt].x);
float v2 = CastFunctor<scalar_t,float>()(x_local[cnt].y);
variance += v1 * v1 + v2 * v2;
}
block_reduce<float, ReduceType::kSum,1>(&variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
scalar2_t s_variance_2 = CastFunctor<float,scalar2_t>()(s_variance);
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
}
}
template<typename scalar_t, int unroll_factor>
__global__ void general_rms_layernorm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
float x_local[8];
int row_offset = blockIdx.x * hidden_size;
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = (float) input[id];
variance += x_local[cnt] * x_local[cnt];
}
block_reduce<float, ReduceType::kSum,1>(&variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
}
}
// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
__global__ void fused_add_rms_layernorm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
using scalar2_t = typename TypeConverter<scalar_t>::Type;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kAdd> add_scalar2t;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
__shared__ float s_variance;
scalar2_t x_local[4];
scalar2_t* input_ptr = (scalar2_t*)input;
scalar2_t* residual_ptr = (scalar2_t*)residual;
const scalar2_t* weight_ptr = (const scalar2_t*)weight;
float variance = 0.0f;
int row_offset = blockIdx.x * hidden_size / 2;
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input_ptr[id];
x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]);
float v1 = CastFunctor<scalar_t,float>()(x_local[cnt].x);
float v2 = CastFunctor<scalar_t,float>()(x_local[cnt].y);
variance += v1 * v1 + v2 * v2;
residual_ptr[id] = x_local[cnt];
}
block_reduce<float, ReduceType::kSum,1>(&variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
scalar2_t s_variance_2 = CastFunctor<float, scalar2_t>()(s_variance);
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
}
}
template<typename scalar_t, int unroll_factor>
__global__ void general_fused_add_rms_layernorm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
float x_local[8];
int row_offset = blockIdx.x * hidden_size;
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = (float) input[id];
x_local[cnt] += (float) residual[id];
variance += x_local[cnt] * x_local[cnt];
residual[id] = (scalar_t) x_local[cnt];
}
block_reduce<float, ReduceType::kSum,1>(&variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
}
}
void rms_layernorm(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
} else {
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
}
} else {
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 2:
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 4:
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 8:
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
default:
AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
}
}
}
void fused_add_rms_layernorm(
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
} else {
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
}
} else {
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 2:
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 4:
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
case 8:
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);)
break;
default:
AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
}
}
}
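
To spell out the math the launchers above dispatch to, here is a compact PyTorch reference of RMSNorm and its fused residual-add variant (float accumulation, mirroring the kernels' float variance accumulator). It is meant as a correctness oracle for the CUDA path, not an equivalent implementation.

```python
import torch

def rms_layernorm_reference(x: torch.Tensor, weight: torch.Tensor,
                            eps: float) -> torch.Tensor:
    # x: [..., hidden_size], weight: [hidden_size]
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_add_rms_layernorm_reference(x, residual, weight, eps):
    # Mirrors the kernel's in-place contract: `residual` ends up holding x + residual
    # and `x` ends up holding the normalized sum; here both are simply returned.
    summed = x + residual
    return rms_layernorm_reference(summed, weight, eps), summed
```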


@@ -0,0 +1,500 @@
/* This code is from NVIDIA Megatron
 * with minor changes. */
#pragma once
#include <assert.h>
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include "utils/vector_copy_utils.h"
namespace {
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template <typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional
* features 1) input scaling 2) Explicit masking
*/
template <typename input_t, typename output_t, typename acc_t,
int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
int micro_batch_size, int element_count, int pad_batches) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch =
(blockDim.y *
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
threadIdx.y) *
WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch =
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i * element_count + it * WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -10000.0;
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t,
int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
int micro_batch_size, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset =
first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_grad, grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_output, output + i * element_count + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] =
(acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] =
(output_t)(scale * (grad_reg[i][it + element] -
output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
int attn_heads) {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
return batches_per_block;
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
const uint8_t *mask,
const input_t scale,
int query_seq_len, int key_seq_len,
int batches, int attn_heads,
int pad_batches) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_forward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
}
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len, int key_seq_len,
int batches, int attn_heads) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
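
The backward kernel above applies the softmax Jacobian-vector product with the scale factor folded in, i.e. dX = scale * y * (dY - sum(dY * y)) along the key dimension. A short PyTorch sketch of that formula (ignoring the warp-level data layout) is:

```python
import torch

def scaled_masked_softmax_backward_reference(grad_output: torch.Tensor,
                                             softmax_out: torch.Tensor,
                                             scale: float) -> torch.Tensor:
    # grad_output, softmax_out: [batches, attn_heads, query_seq_len, key_seq_len]
    prod = grad_output.float() * softmax_out.float()
    row_sum = prod.sum(dim=-1, keepdim=True)
    return (scale * (prod - softmax_out.float() * row_sum)).to(grad_output.dtype)
```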


@@ -0,0 +1,89 @@
/* This code is from NVIDIA Megatron
 * with minor changes. */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "../common/micros.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
int attn_heads) {
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor) {
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len,
// seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results = torch::empty(
{batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len,
// seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor, query_seq_len, key_seq_len, batches, attn_heads););
// backward pass is completely in-place
return output_grads;
}
} // namespace scaled_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn
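
A natural way to consume the `fwd_cuda`/`bwd_cuda` pair from Python is to wrap them in a `torch.autograd.Function`. The sketch below assumes the compiled module exposing `forward`/`backward` (from the binding file earlier in this commit) is passed in as `ext`, and that the in-place backward is acceptable to the caller.

```python
import torch

class ScaledMaskedSoftmax(torch.autograd.Function):
    """Autograd wrapper around the fused forward/backward bindings (sketch)."""

    @staticmethod
    def forward(ctx, inputs, mask, scale, ext):
        softmax_results = ext.forward(inputs, mask, scale)
        ctx.save_for_backward(softmax_results)
        ctx.scale = scale
        ctx.ext = ext
        return softmax_results

    @staticmethod
    def backward(ctx, grad_output):
        (softmax_results,) = ctx.saved_tensors
        # bwd_cuda overwrites its gradient argument in place and returns it.
        grad_input = ctx.ext.backward(grad_output.contiguous(), softmax_results, ctx.scale)
        # No gradients for mask, scale, or the extension handle.
        return grad_input, None, None, None
```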


@@ -0,0 +1,538 @@
/* This code is from NVIDIA Megatron
 * with minor changes. */
#pragma once
#include <assert.h>
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <cfloat>
#include <limits>
#include "utils/vector_copy_utils.h"
namespace {
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template <typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional
* features 1) input scaling 2) Implicit time (diagonal masking)
*/
template <typename input_t, typename output_t, typename acc_t,
int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size,
int stride, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch =
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit =
(local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_data, src + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t,
int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
int micro_batch_size, int stride, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch =
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] =
(acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] =
(output_t)(scale * (grad_reg[i][it + element] -
output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst, const input_t *src, const input_t scale,
int softmax_elements, int softmax_elements_stride, int attn_batches) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_forward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
default:
break;
}
}
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input, input_t *grad, const input_t *output,
const acc_t scale, int softmax_elements, int softmax_elements_stride,
int attn_batches) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}

View File

@@ -0,0 +1,75 @@
/* This code is from NVIDIA Megatron, with minor changes. */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "../common/micros.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t,
float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr), scale_factor, seq_len,
seq_len, attn_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 3d tensor with dimensions [attn_batches, seq_len,
// seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t,
float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor, seq_len, seq_len, attn_batches););
// backward pass is completely in-place
return output_grads;
}
} // namespace scaled_upper_triang_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn

View File

@@ -0,0 +1,78 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "nvgpu_dev_info.h"
namespace colossalAI {
namespace cuda {
namespace utils {
struct GPULaunchConfig {
dim3 block{1, 1, 1};
dim3 grid{1, 1, 1};
};
static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info,
int64_t numel, int64_t vec_size) {
const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock();
const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0];
const int64_t kMinimumSize = 64;
const int64_t kMaximumSize = 512;
int64_t active_threads = (numel + vec_size - 1) / vec_size;
int64_t sm_num = dev_info.GetMultiProcessorCount();
// Note(LiuYang): expected threads should be in [64, 128, 256, 512] generally
int64_t expected_threads_per_block = kMaximumSize;
auto RoundUpToPowerOfTwo = [](int64_t x) {
bool is_power_of_two = false;
int64_t ret = 1;
int64_t y = x;
while (y > 0) {
is_power_of_two = ((ret ^ x) == 0);
y = (y >> 1);
ret = (ret << 1);
if (y > 0) is_power_of_two = false;
}
if (is_power_of_two) return x;
return ret;
};
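  // Heuristic: spread the active threads over roughly two blocks per SM; if that
  // would exceed the per-block thread limit, try four blocks per SM, and otherwise
  // keep the 512-thread default chosen above.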
if ((active_threads / (sm_num << 1)) < max_threads_per_block) {
expected_threads_per_block =
RoundUpToPowerOfTwo(active_threads / (sm_num << 1));
} else if ((active_threads / (sm_num << 2)) < max_threads_per_block) {
expected_threads_per_block =
RoundUpToPowerOfTwo(active_threads / (sm_num << 2));
}
expected_threads_per_block =
std::max(expected_threads_per_block, kMinimumSize);
int64_t expect_block_per_grid =
((active_threads + expected_threads_per_block - 1) /
expected_threads_per_block);
if (expect_block_per_grid > max_blocks_per_grid) {
expect_block_per_grid = max_blocks_per_grid;
expected_threads_per_block =
(active_threads + expect_block_per_grid - 1) / expect_block_per_grid;
if (expected_threads_per_block > max_threads_per_block)
throw std::invalid_argument(
"Threads required for current input exceed for current GPU!");
expected_threads_per_block =
RoundUpToPowerOfTwo(expected_threads_per_block);
expect_block_per_grid = ((active_threads + expected_threads_per_block - 1) /
expected_threads_per_block);
}
GPULaunchConfig config;
config.block.x = expected_threads_per_block;
config.grid.x = expect_block_per_grid;
return config;
}
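// Illustrative usage (a sketch, not part of the original sources; `elementwise_kernel`
// and `numel` are hypothetical):
//   NVGPUDevInfo info(0);
//   GPULaunchConfig cfg = GetGPULaunchConfig1D(info, numel, /*vec_size=*/4);
//   elementwise_kernel<<<cfg.grid, cfg.block>>>(...);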
} // namespace utils
} // namespace cuda
} // namespace colossalAI

View File

@@ -0,0 +1,18 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <exception>
#include <stdexcept>
#define CUDA_CHECK(func) \
{ \
auto status = func; \
if (status != cudaSuccess) { \
throw std::runtime_error(cudaGetErrorString(status)); \
} \
}
#define HOST __host__
#define DEVICE __device__
#define HOSTDEVICE __host__ __device__

View File

@@ -0,0 +1,60 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <ostream>
#include <string>
#include <array>
#include <vector>
#include "micros.h"
namespace colossalAI {
namespace cuda {
namespace utils {
class NVGPUDevInfo {
public:
explicit NVGPUDevInfo(int device_num) : device_num_(device_num) {
CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num));
}
std::array<int, 3> GetMaxGridDims() const {
std::array<int, 3> ret;
ret[0] = prop_.maxGridSize[0];
ret[1] = prop_.maxGridSize[1];
ret[2] = prop_.maxGridSize[2];
return ret;
}
std::array<int, 3> GetMaxBlockDims() const {
std::array<int, 3> ret;
ret[0] = prop_.maxThreadsDim[0];
ret[1] = prop_.maxThreadsDim[1];
ret[2] = prop_.maxThreadsDim[2];
return ret;
}
std::array<int, 2> GetCapability() const {
std::array<int, 2> ret;
ret[0] = prop_.major;
ret[1] = prop_.minor;
return ret;
}
int GetMultiProcessorCount() const { return prop_.multiProcessorCount; }
int GetMaxThreadsPerMultiProcessor() const {
return prop_.maxThreadsPerMultiProcessor;
}
int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; }
private:
int device_num_;
cudaDeviceProp prop_;
};
} // namespace utils
} // namespace cuda
} // namespace colossalAI

View File

@@ -0,0 +1,83 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <cfloat>
namespace colossalAI {
namespace cuda {
namespace utils {
template <typename T, int VecSize>
struct VecTypeTrait {};
template <typename T>
struct VecTypeTrait<T, 1> {
using Type = T;
};
template <>
struct VecTypeTrait<c10::BFloat16, 2> {
using Type = float;
};
template <>
struct VecTypeTrait<c10::BFloat16, 4> {
using Type = float2;
};
template <>
struct VecTypeTrait<c10::BFloat16, 8> {
using Type = float4;
};
template <>
struct VecTypeTrait<c10::Half, 2> {
using Type = float;
};
template <>
struct VecTypeTrait<c10::Half, 4> {
using Type = float2;
};
template <>
struct VecTypeTrait<c10::Half, 8> {
using Type = float4;
};
template <>
struct VecTypeTrait<float, 2> {
using Type = float2;
};
template <>
struct VecTypeTrait<float, 4> {
using Type = float4;
};
template <>
struct VecTypeTrait<float, 8> {
using Type = float4;
};
template <>
struct VecTypeTrait<uint8_t, 2> {
using Type = half;
};
template <>
struct VecTypeTrait<uint8_t, 4> {
using Type = half2;
};
template <>
struct VecTypeTrait<uint8_t, 8> {
using Type = float2;
};
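// Each specialization picks a plain vector type whose width matches the group of
// VecSize elements (e.g. two halves travel as one 32-bit float, eight halves as a
// float4); data is only reinterpreted for wide loads/stores, never converted. The
// float/VecSize==8 case would exceed the 128-bit limit, so it maps to float4 and is
// copied in two halves by its caller.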
} // namespace utils
} // namespace cuda
} // namespace colossalAI

View File

@@ -0,0 +1,52 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include "vec_type_traits.h"
template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
// Note(LiuYang): Here static_cast can't be used to cast between two pointer types
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}
template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
// Since the maximum memory alignment length is 128 bits, we choose float4
// here.
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
*(reinterpret_cast<float4 *>(dst + 4)) =
*(reinterpret_cast<const float4 *>(src + 4));
}
template <typename T, int VecSize>
__device__ __inline__ void copy_zero_vector(T *dst) {
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
*(reinterpret_cast<VT *>(dst)) = {0.0};
}
template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
const int max_aligned_size = 128;
const int dtype_size = sizeof(T) * 8;
const int vec_size = max_aligned_size / sizeof(T) / 8;
// Note(LiuYang): Performance of situation of which
// vec_size equals to 8 need to be profiled in the future
// if (address % (dtype_size * 8) == 0) {
// return std::min(8, vec_size);
// }
if (address % (dtype_size * 4) == 0) {
return std::min(4, vec_size);
} else if (address % (dtype_size * 2) == 0) {
return std::min(2, vec_size);
} else {
return 1;
}
}
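// Worked example: for float tensors, max_aligned_size / sizeof(T) / 8 = 128 / 4 / 8 = 4,
// so the widest access is a float4; the returned width shrinks further whenever the
// data pointer fails the alignment checks above.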

View File

@@ -0,0 +1,190 @@
# This code is from NVIDIA Megatron, with minor changes.
import enum
import torch
import torch.nn as nn
from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader
try:
from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
except ImportError:
scaled_masked_softmax = None
scaled_upper_triang_masked_softmax = None
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
global scaled_upper_triang_masked_softmax
if scaled_upper_triang_masked_softmax is None:
scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
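# Illustrative usage (a sketch, not part of the original file; assumes a CUDA device
# and a built kernel). The input is a 3D fp16/bf16 tensor of shape
# [attn_batches, seq_len, seq_len] with seq_len <= 2048:
#   scores = torch.randn(16, 128, 128, dtype=torch.float16, device="cuda")
#   probs = ScaledUpperTriangMaskedSoftmax.apply(scores, 0.125)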
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
scale_t = torch.tensor([scale])
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
Fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: Flag to indicate if the input is in fp16 data format.
input_in_bf16: Flag to indicate if the input is in bf16 data format.
attn_mask_type: Attention mask type (pad or causal).
scaled_masked_softmax_fusion: Flag to indicate whether the user wants to use softmax fusion.
mask_func: Mask function to be applied.
softmax_in_fp32: If True, softmax is performed at fp32 precision.
scale: Scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion  # the user wants to fuse
and self.input_in_float16  # input must be fp16 or bf16
and mask is not None  # mask tensor must not be None
and 16 < sk <= 2048  # sk must be in (16, 2048]
and sq % 4 == 0  # sq must be divisible by 4
and attn_batches % 4 == 0  # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
def get_batch_per_block(self, sq, sk, b, np):
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
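# Illustrative usage (a sketch, not part of the original file; the mask function and
# shapes are hypothetical):
#   softmax = FusedScaleMaskSoftmax(
#       input_in_fp16=True,
#       input_in_bf16=False,
#       attn_mask_type=AttnMaskType.causal,
#       scaled_masked_softmax_fusion=True,
#       mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
#       softmax_in_fp32=True,
#       scale=0.125,
#   )
#   probs = softmax(attn_scores, attn_mask)  # attn_scores: [b, np, sq, sk]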

View File

@@ -0,0 +1,446 @@
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
#include <math.h>
#include <omp.h>
#include <string.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
// C++ interface
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
}
AVX_Data momentum_4;
this->simd_load(momentum_half_precision, _exp_avg + i,
momentum_cast_h + i, momentum_4);
AVX_Data variance_4;
this->simd_load(variance_half_precision, _exp_avg_sq + i,
variance_cast_h + i, variance_4);
AVX_Data param_4;
this->simd_load(param_half_precision, _params + i, params_cast_h + i,
param_4);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
}
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data =
SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data =
SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
if (_weight_decay > 0 && _adamw_mode) {
param_4.data =
SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
}
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
this->simd_store(param_half_precision, _params + i, params_cast_h + i,
param_4);
this->simd_store(momentum_half_precision, _exp_avg + i,
momentum_cast_h + i, momentum_4);
this->simd_store(variance_half_precision, _exp_avg_sq + i,
variance_cast_h + i, variance_4);
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
if (loss_scale > 0) {
grad /= loss_scale;
}
float param =
param_half_precision ? (float)params_cast_h[k] : _params[k];
float momentum =
momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
float variance = variance_half_precision ? (float)variance_cast_h[k]
: _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) {
param += w_decay * param;
}
param = grad * step_size + param;
if (param_half_precision)
params_cast_h[k] = (__half)param;
else
_params[k] = param;
if (momentum_half_precision)
momentum_cast_h[k] = (__half)(momentum);
else
_exp_avg[k] = momentum;
if (variance_half_precision)
variance_cast_h[k] = (__half)(variance);
else
_exp_avg_sq[k] = variance;
}
}
}
}
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
float betta2_minus1 = 1 - _betta2;
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
AVX_Data grad_4[4];
AVX_Data momentum_4[4];
AVX_Data variance_4[4];
AVX_Data param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}
this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_load(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
}
momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
momentum_4[j].data =
SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
variance_4[j].data =
SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
grad_4[j].data = SIMD_SQRT(variance_4[j].data);
grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
}
param_4[j].data =
SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_store(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
}
}
}
#endif
if (_param_size > rounded_size)
Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}
void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
float betta2_minus1 = 1 - _betta2;
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
AVX_Data grad_4[8];
AVX_Data momentum_4[8];
AVX_Data variance_4[8];
AVX_Data param_4[8];
#pragma unroll 8
for (int j = 0; j < 8; j++) {
this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}
this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_load(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
}
momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
momentum_4[j].data =
SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
variance_4[j].data =
SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
grad_4[j].data = SIMD_SQRT(variance_4[j].data);
grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
}
param_4[j].data =
SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_store(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
}
}
}
#endif
if (_param_size > rounded_size)
Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}
void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float *params_ptr = (float *)params_c.data_ptr();
float *grads_ptr = (float *)grads_c.data_ptr();
float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
params_c.numel(), (params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf),
(exp_avg.options().dtype() == at::kHalf),
(exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
}
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
.def(py::init<float, float, float, float, float, bool>())
.def("step", &Adam_Optimizer::step);
}

View File

@@ -0,0 +1,185 @@
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <torch/extension.h>
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))
#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))
#endif
union AVX_Data {
#if defined(__AVX512__)
__m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
__m256 data;
#endif
// float data_f[16];
};
#endif
#define STEP(SPAN) \
void Step_##SPAN( \
float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
size_t _param_size, bool param_half_precision = false, \
bool grad_half_precision = false, bool momentum_half_precision = false, \
bool variance_half_precision = false, float loss_scale = -1);
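// STEP(SPAN) declares Step_1, Step_4 and Step_8 below: Step_N consumes
// SIMD_WIDTH * N floats per loop iteration and hands the remaining tail that does
// not fill a full stride to the next smaller variant.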
class Adam_Optimizer {
public:
Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {}
~Adam_Optimizer() {}
STEP(1)
STEP(4)
STEP(8)
inline void IncrementStep(size_t step, float beta1, float beta2) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
inline void update_state(float lr, float epsilon, float weight_decay,
bool bias_correction) {
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;
_bias_correction1 = 1.0f;
_bias_correction2 = 1.0f;
if (bias_correction == 1) {
_bias_correction1 = 1 - _betta1_t;
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
}
}
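  // With these definitions the Step_* routines apply the bias-corrected Adam update
  //   param -= (alpha / (1 - beta1^t)) * m / (sqrt(v) / sqrt(1 - beta2^t) + eps)
  // with L2-style weight decay folded into the gradient, or decoupled (AdamW) decay
  // applied directly to the parameters when adamw_mode is set.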
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
data.data = SIMD_LOAD_HALF(h_ptr);
} else {
data.data = SIMD_LOAD(ptr);
}
}
inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
SIMD_STORE_HALF(h_ptr, data.data);
} else {
SIMD_STORE(ptr, data.data);
}
}
#endif
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
private:
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
size_t _step;
float _bias_correction1;
float _bias_correction2;
bool _adamw_mode;
};

View File

@@ -0,0 +1,109 @@
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List
from .base_extension import _Extension
from .cpp_extension import _CppExtension
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
__all__ = ["_CudaExtension"]
# Some constants for installation checks
MIN_PYTORCH_VERSION_MAJOR = 1
MIN_PYTORCH_VERSION_MINOR = 10
class _CudaExtension(_CppExtension):
@abstractmethod
def nvcc_flags(self) -> List[str]:
"""
This function should return a list of nvcc compilation flags for extensions.
"""
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_compatible(self) -> None:
from torch.utils.cpp_extension import CUDA_HOME
if not CUDA_HOME:
raise AssertionError(
"[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"
)
check_system_pytorch_cuda_match(CUDA_HOME)
check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)
def get_cuda_home_include(self):
"""
return include path inside the cuda home.
"""
from torch.utils.cpp_extension import CUDA_HOME
if CUDA_HOME is None:
raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
cuda_include = os.path.join(CUDA_HOME, "include")
return cuda_include
def build_jit(self) -> None:
from torch.utils.cpp_extension import CUDA_HOME, load
set_cuda_arch_list(CUDA_HOME)
# get build dir
build_directory = _Extension.get_jit_extension_folder_path()
build_directory = Path(build_directory)
build_directory.mkdir(parents=True, exist_ok=True)
# check if the kernel has been built
compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
if kernel_file_path.exists():
compiled_before = True
# load the kernel
if compiled_before:
print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
else:
print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
build_start = time.time()
op_kernel = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_cuda_cflags=self.nvcc_flags(),
extra_ldflags=[],
build_directory=str(build_directory),
)
build_duration = time.time() - build_start
if compiled_before:
print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
else:
print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
return op_kernel
def build_aot(self) -> "CUDAExtension":
from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension
set_cuda_arch_list(CUDA_HOME)
return CUDAExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args={
"cxx": self.strip_empty_entries(self.cxx_flags()),
"nvcc": self.strip_empty_entries(self.nvcc_flags()),
},
)

View File

@@ -0,0 +1,14 @@
from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
from .flash_attention_npu import FlashAttentionNpuExtension
from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension
try:
# TODO: remove this after updating openmoe example
import flash_attention # noqa
HAS_FLASH_ATTN = True
except:
HAS_FLASH_ATTN = False
__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"]

View File

@@ -0,0 +1,96 @@
from ..base_extension import _Extension
class FlashAttentionDaoCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func # noqa
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_compatible(self) -> bool:
pass
def build_aot(self) -> None:
raise NotImplementedError(
"We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'."
)
def build_jit(self) -> None:
raise NotImplementedError(
"We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
)
def load(self):
from typing import Optional
import torch
from einops import rearrange
from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func
from flash_attn.bert_padding import index_first_axis, pad_input
def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
):
# [B, N, S, D] -> [B, S, N, D]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
b, s_q = q.shape[:2]
if cu_seqlens_q is not None:
# padded / padded causal
# unpad input: [B, S, N, D] -> [T, N, D]
q = _unpad_input(q, q_indices)
kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
attn_output = flash_attn_varlen_kvpacked_func(
q,
kv,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
)
# pad output: [T, N, D] -> [B, S, N, D]
attn_output = pad_input(attn_output, q_indices, b, s_q)
else:
# causal / no attn mask
attn_output = flash_attn_func(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
)
# [B, S, N, D] -> [B, N, S, D]
return attn_output.transpose(1, 2)
return flash_attention

View File

@@ -0,0 +1,62 @@
from ..base_extension import _Extension
class FlashAttentionNpuExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)
def is_available(self) -> bool:
try:
import torch_npu
return hasattr(torch_npu, "npu_fusion_attention")
except:
return False
def assert_compatible(self) -> bool:
pass
def build_aot(self) -> None:
raise NotImplementedError(
"Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu."
)
def build_jit(self) -> None:
raise NotImplementedError(
"Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu."
)
def load(self):
from typing import Optional
import torch
import torch_npu
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
):
num_heads = q.size(1)
return torch_npu.npu_fusion_attention(
q,
k,
v,
num_heads,
"BNSD",
atten_mask=attention_mask.bool(),
scale=scale,
keep_prob=1 - dropout_p,
)[0]
return flash_attention

View File

@@ -0,0 +1,56 @@
from ..base_extension import _Extension
class FlashAttentionSdpaCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False)
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_compatible(self) -> bool:
pass
def build_aot(self) -> None:
raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.")
def build_jit(self) -> None:
raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.")
def load(self):
from typing import Optional
import torch
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
):
return torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attention_mask,
dropout_p=dropout_p,
scale=scale,
)
return flash_attention

View File

@@ -0,0 +1,3 @@
from .inference_ops_cuda import InferenceOpsCudaExtension
__all__ = ["InferenceOpsCudaExtension"]

View File

@@ -0,0 +1,35 @@
from ..cuda_extension import _CudaExtension
from ..utils import get_cuda_cc_flag
class InferenceOpsCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="inference_ops_cuda")
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in [
"cuda/pybind/inference.cpp",
"cuda/decode_kv_cache_memcpy_kernel.cu",
"cuda/context_kv_cache_memcpy_kernel.cu",
"cuda/fused_rotary_emb_and_cache_kernel.cu",
"cuda/activation_kernel.cu",
"cuda/rms_layernorm_kernel.cu",
"cuda/get_cos_and_sin_kernel.cu",
]
]
return ret
def include_dirs(self):
ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
return ret
def cxx_flags(self):
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
return ["-O3"] + version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ["-lineinfo"]
extra_cuda_flags.extend(get_cuda_cc_flag())
return ["-O3", "--use_fast_math"] + extra_cuda_flags

View File

@@ -0,0 +1,3 @@
from .layernorm_cuda import LayerNormCudaExtension
__all__ = ["LayerNormCudaExtension"]

View File

@@ -0,0 +1,24 @@
from ..cuda_extension import _CudaExtension
from ..utils import append_nvcc_threads, get_cuda_cc_flag
class LayerNormCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="layernorm_cuda")
def sources_files(self):
ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/layer_norm.cpp", "cuda/layer_norm_kernel.cu"]]
return ret
def include_dirs(self):
ret = [self.get_cuda_home_include()]
return ret
def cxx_flags(self):
return ["-O3"] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ["-maxrregcount=50"]
extra_cuda_flags.extend(get_cuda_cc_flag())
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros
return append_nvcc_threads(ret)

View File

@@ -0,0 +1,3 @@
from .moe_cuda import MoeCudaExtension
__all__ = ["MoeCudaExtension"]

View File

@@ -0,0 +1,29 @@
from ..cuda_extension import _CudaExtension
from ..utils import append_nvcc_threads, get_cuda_cc_flag
class MoeCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="moe_cuda")
def include_dirs(self):
ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
return ret
def sources_files(self):
ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/moe.cpp", "cuda/moe_kernel.cu"]]
return ret
def cxx_flags(self):
return ["-O3"] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = [
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
]
extra_cuda_flags.extend(get_cuda_cc_flag())
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@@ -0,0 +1,3 @@
from .fused_optimizer_cuda import FusedOptimizerCudaExtension
__all__ = ["FusedOptimizerCudaExtension"]

View File

@@ -0,0 +1,34 @@
from ..cuda_extension import _CudaExtension
from ..utils import get_cuda_cc_flag
class FusedOptimizerCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="fused_optim_cuda")
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in [
"cuda/pybind/optimizer.cpp",
"cuda/multi_tensor_sgd_kernel.cu",
"cuda/multi_tensor_scale_kernel.cu",
"cuda/multi_tensor_adam_kernel.cu",
"cuda/multi_tensor_l2norm_kernel.cu",
"cuda/multi_tensor_lamb_kernel.cu",
]
]
return ret
def include_dirs(self):
ret = [self.get_cuda_home_include()]
return ret
def cxx_flags(self):
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
return ["-O3"] + version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = ["-lineinfo"]
extra_cuda_flags.extend(get_cuda_cc_flag())
return ["-O3", "--use_fast_math"] + extra_cuda_flags

View File

@@ -0,0 +1,4 @@
from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension
from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension
__all__ = ["ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension"]

View File

@@ -0,0 +1,32 @@
from ..cuda_extension import _CudaExtension
from ..utils import append_nvcc_threads
class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="scaled_masked_softmax_cuda")
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in ["cuda/pybind/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_kernel.cu"]
]
return ret
def include_dirs(self):
return [self.get_cuda_home_include()]
def cxx_flags(self):
return ["-O3"] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
]
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@@ -0,0 +1,34 @@
from ..cuda_extension import _CudaExtension
from ..utils import append_nvcc_threads, get_cuda_cc_flag
class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
def __init__(self):
super().__init__(name="scaled_upper_triangle_masked_softmax_cuda")
def include_dirs(self):
return [self.get_cuda_home_include()]
def sources_files(self):
ret = [
self.csrc_abs_path(fname)
for fname in [
"cuda/pybind/scaled_upper_triang_masked_softmax.cpp",
"cuda/scaled_upper_triang_masked_softmax_kernel.cu",
]
]
return ret
def cxx_flags(self):
return ["-O3"] + self.version_dependent_macros
def nvcc_flags(self):
extra_cuda_flags = [
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
]
extra_cuda_flags.extend(get_cuda_cc_flag())
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags
return append_nvcc_threads(ret)

View File

@@ -0,0 +1,21 @@
from .base_extension import _Extension
__all__ = ["_TritonExtension"]
class _TritonExtension(_Extension):
def __init__(self, name: str, priority: int = 1):
super().__init__(name, support_aot=False, support_jit=True, priority=priority)
def is_hardware_compatible(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def load(self):
return self.build_jit()

View File

@@ -0,0 +1,229 @@
import os
import re
import subprocess
import warnings
from typing import List
def print_rank_0(message: str) -> None:
"""
Print on only one process to avoid spamming.
"""
try:
import torch.distributed as dist
if not dist.is_initialized():
is_main_rank = True
else:
is_main_rank = dist.get_rank() == 0
except ImportError:
is_main_rank = True
if is_main_rank:
print(message)
def get_cuda_version_in_pytorch() -> List[int]:
"""
This function returns the CUDA version in the PyTorch build.
Returns:
The CUDA version required by PyTorch, in the form of tuple (major, minor).
"""
import torch
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
except:
raise ValueError(
"[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda"
)
return torch_cuda_major, torch_cuda_minor
def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
"""
Get the System CUDA version from nvcc.
Args:
cuda_dir (str): the directory for CUDA Toolkit.
Returns:
The system CUDA version reported by nvcc, in the form of tuple (major, minor).
"""
if cuda_dir is None:
raise ValueError(
"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure you have exported the environment variable CUDA_HOME correctly."
)
nvcc_path = os.path.join(cuda_dir, "bin/nvcc")
# check for nvcc path
if not os.path.exists(nvcc_path):
raise FileNotFoundError(
f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
)
# parse the nvcc -v output to obtain the system cuda version
try:
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
except:
raise ValueError(
f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}"
)
return bare_metal_major, bare_metal_minor
def check_system_pytorch_cuda_match(cuda_dir):
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()
if bare_metal_major != torch_cuda_major:
raise Exception(
f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})."
"Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
)
if bare_metal_minor != torch_cuda_minor:
warnings.warn(
f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
"The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
"If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
)
return True
def get_pytorch_version() -> List[int]:
"""
    This function finds the PyTorch version.
Returns:
A tuple of integers in the form of (major, minor, patch).
"""
import torch
torch_version = torch.__version__.split("+")[0]
TORCH_MAJOR = int(torch_version.split(".")[0])
TORCH_MINOR = int(torch_version.split(".")[1])
TORCH_PATCH = int(torch_version.split(".")[2], 16)
return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH
def check_pytorch_version(min_major_version, min_minor_version) -> bool:
"""
Compare the current PyTorch version with the minium required version.
Args:
min_major_version (int): the minimum major version of PyTorch required
min_minor_version (int): the minimum minor version of PyTorch required
Returns:
A boolean value. The value is True if the current pytorch version is acceptable and False otherwise.
"""
    # get pytorch version
    torch_major, torch_minor, _ = get_pytorch_version()
    # raise an error if the installed PyTorch is older than the minimum required version
    if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
        raise RuntimeError(
            f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
            "The latest stable release can be obtained from https://pytorch.org/get-started/locally/"
        )
    return True
def check_cuda_availability():
"""
Check if CUDA is available on the system.
Returns:
A boolean value. True if CUDA is available and False otherwise.
"""
import torch
return torch.cuda.is_available()
def set_cuda_arch_list(cuda_dir):
"""
This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
Ahead-of-time compilation occurs when BUILD_EXT=1 is set when running 'pip install'.
"""
cuda_available = check_cuda_availability()
# we only need to set this when CUDA is not available for cross-compilation
if not cuda_available:
        warnings.warn(
            "\n[extension] PyTorch did not find available GPUs on this system.\n"
            "If your intention is to cross-compile, this is not an error.\n"
            "By default, Colossal-AI will cross-compile for:\n"
            "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
            "2. Volta (compute capability 7.0),\n"
            "3. Turing (compute capability 7.5),\n"
            "4. Ampere (compute capabilities 8.0, 8.6) if the CUDA version is >= 11.0.\n"
            "\nIf you wish to cross-compile for a single specific architecture,\n"
            'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
        )
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]
            if int(bare_metal_major) == 11:
                arch_list.append("8.0")
                if int(bare_metal_minor) > 0:
                    arch_list.append("8.6")
arch_list_str = ";".join(arch_list)
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
return False
return True
def get_cuda_cc_flag() -> List[str]:
"""
This function produces the cc flags for your GPU arch
Returns:
The CUDA cc flags for compilation.
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
cc_flag = []
max_arch = "".join(str(i) for i in torch.cuda.get_device_capability())
for arch in torch.cuda.get_arch_list():
res = re.search(r"sm_(\d+)", arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch):
cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"])
return cc_flag
def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
"""
This function appends the threads flag to your nvcc args.
Returns:
The nvcc compilation flags including the threads flag.
"""
from torch.utils.cpp_extension import CUDA_HOME
bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    # `nvcc --threads` is supported since CUDA 11.2
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
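Taken together, these helpers cover the checks an ahead-of-time build runs before compiling any CUDA extension. Below is a hedged end-to-end sketch: the import path and the minimum PyTorch version are assumptions, and `CUDA_HOME` comes from PyTorch and may be `None` when nvcc is not installed.

```python
# A minimal sketch wiring the utilities above together for an AOT build.
from torch.utils.cpp_extension import CUDA_HOME

from extensions.utils import (  # import path assumed
    append_nvcc_threads,
    check_cuda_availability,
    check_pytorch_version,
    check_system_pytorch_cuda_match,
    get_cuda_cc_flag,
    print_rank_0,
    set_cuda_arch_list,
)

check_pytorch_version(1, 12)  # example minimum; raises if the installed torch is too old

if check_cuda_availability() and CUDA_HOME is not None:
    check_system_pytorch_cuda_match(CUDA_HOME)  # compare nvcc with torch.version.cuda
    set_cuda_arch_list(CUDA_HOME)               # fills TORCH_CUDA_ARCH_LIST when unset
    nvcc_flags = append_nvcc_threads(["-O3", "--use_fast_math"] + get_cuda_cc_flag())
    print_rank_0(f"[extension] nvcc flags: {nvcc_flags}")
elif CUDA_HOME is not None:
    # no visible GPU: set_cuda_arch_list falls back to the default cross-compilation list
    set_cuda_arch_list(CUDA_HOME)
```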

View File

@ -201,6 +201,7 @@ class Linear1D_Col(ParallelModule):
if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
elif self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim

View File

@ -1 +0,0 @@
../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py

View File

@ -0,0 +1,352 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
from types import MethodType
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
repeat_kv,
)
from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
if get_accelerator().name == "cuda":
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
from flash_attn.ops.rms_norm import rms_norm
def _prepare_decoder_attention_mask(
self: LlamaModel,
attention_mask: torch.BoolTensor,
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
past_key_values_length: int,
) -> Optional[torch.Tensor]:
"""
    Decoder attention mask: returns None when no position is masked so the faster unmasked path can be used.
"""
if past_key_values_length > 0 and attention_mask is not None:
attention_mask = torch.cat(
tensors=(
torch.full(
size=(input_shape[0], past_key_values_length),
fill_value=True,
dtype=attention_mask.dtype,
device=attention_mask.device,
),
attention_mask,
),
dim=-1,
) # (bsz, past_key_values_length + q_len)
if attention_mask is not None and torch.all(attention_mask):
return None # Faster
return attention_mask
def attention_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
"""
if output_attentions:
            logger.warning(
                "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
                "returning `None` instead."
            )
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
q_slicing, kv_slicing = (
dim // self.config.pretraining_tp
for dim in (
self.num_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
)
) # `Tuple[int, int]`
q_slices, k_slices, v_slices = (
proj.weight.split(slicing, dim=0)
for proj, slicing in (
(self.q_proj, q_slicing),
(self.k_proj, kv_slicing),
(self.v_proj, kv_slicing),
)
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
q, k, v = (
torch.cat(
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
dim=-1,
)
for slices in (q_slices, k_slices, v_slices)
)
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
else:
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
# (bsz, q_len, num_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim),
# (bsz, q_len, num_key_value_heads * head_dim)
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
q, k, v = (
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
for states, num_heads in (
(q, self.num_heads),
(k, self.num_key_value_heads),
(v, self.num_key_value_heads),
)
)
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
past_kv_len = 0
if past_key_value is not None:
# if `past_key_value` is not None, `kv_len` > `q_len`.
past_kv_len = past_key_value[0].shape[-2]
kv_len += past_kv_len
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
cos, sin = self.rotary_emb(v, seq_len=kv_len)
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
k = torch.cat([past_key_value[0], k], dim=2)
v = torch.cat([past_key_value[1], v], dim=2)
past_key_value = (k, v) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
key_padding_mask = attention_mask
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
if past_kv_len > 0:
q = torch.cat(
tensors=(
torch.full(
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
fill_value=0.0,
dtype=q.dtype,
device=q.device,
),
q,
),
dim=1,
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
if key_padding_mask is None:
# (bsz, past_kv_len + q_len, num_heads, head_dim)
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
output = rearrange(
output, pattern="... h d -> ... (h d)"
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
else:
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
kv, _, cu_kv_lens, max_kv_len = unpad_input(
hidden_states=torch.stack(tensors=(k, v), dim=2),
attention_mask=key_padding_mask,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q=q,
kv=kv,
cu_seqlens_q=cu_q_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_q_len,
max_seqlen_k=max_kv_len,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
output = pad_input(
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
indices=indices,
batch=bsz,
seqlen=past_kv_len + q_len,
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
if past_kv_len > 0:
# Strip off the zero query outputs.
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
output = self.o_proj(output) # (bsz, q_len, hidden_size)
return output, None, past_key_value
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
"""
    Forward function for RMSNorm using the fused flash-attn rms_norm kernel.
"""
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
module.forward = MethodType(attention_forward, module)
if isinstance(module, LlamaModel):
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
if isinstance(module, LlamaRMSNorm):
module.forward = MethodType(rms_norm_forward, module)
elif get_accelerator().name == "npu":
import torch_npu
class NPULlamaAttention(LlamaAttention):
use_flash: bool = True
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.setup()
def setup(self):
self._softmax_scale = 1 / math.sqrt(self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if not self.use_flash:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
else:
attn_output, *_ = torch_npu.npu_fusion_attention(
query_states,
key_states,
value_states,
self.num_heads,
"BNSD",
atten_mask=attention_mask.bool(),
scale=self._softmax_scale,
padding_mask=None,
pre_tockens=65535,
next_tockens=0,
keep_prob=1.0,
inner_precise=0,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum(
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
)
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class NPURMSNorm(LlamaRMSNorm):
def forward(self, hidden_states):
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
module.__class__ = NPULlamaAttention
module.setup()
if isinstance(module, LlamaRMSNorm):
module.__class__ = NPURMSNorm
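A hedged usage sketch for the patch above: the checkpoint path is a placeholder and the module path is assumed from the symlink this file replaces; on CUDA it requires `flash-attn`, on Ascend it requires `torch_npu`.

```python
# A minimal sketch: load a LLaMA checkpoint and swap in the patched forward passes.
import torch
from transformers import LlamaForCausalLM

from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention  # path assumed

model = LlamaForCausalLM.from_pretrained(
    "path/to/llama-2-checkpoint",  # placeholder
    torch_dtype=torch.float16,
).to("cuda")

replace_with_flash_attention(model)  # monkey-patches attention / RMSNorm in place
```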