merge commit

pull/5339/head
FrankLeeeee 2024-01-31 10:41:47 +08:00
commit c565519913
306 changed files with 4053 additions and 9120 deletions

View File

@ -140,7 +140,7 @@ jobs:
- name: Install Colossal-AI
run: |
CUDA_EXT=1 pip install -v -e .
BUILD_EXT=1 pip install -v -e .
pip install -r requirements/requirements-test.txt
- name: Store Colossal-AI Cache
@ -160,9 +160,7 @@ jobs:
--ignore tests/test_gptq \
--ignore tests/test_infer_ops \
--ignore tests/test_legacy \
--ignore tests/test_moe \
--ignore tests/test_smoothquant \
--ignore tests/test_checkpoint_io \
tests/
env:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64

View File

@ -12,7 +12,7 @@ jobs:
if: github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 90
steps:
@ -23,6 +23,7 @@ jobs:
ngpu=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
endIndex=$(($ngpu-1))
for i in $(seq 0 $endIndex);
do
gpu_used=$(nvidia-smi -i $i --query-gpu=memory.used --format=csv,noheader,nounits)
[ "$gpu_used" -gt "2000" ] && avai=false
done
@ -54,7 +55,7 @@ jobs:
if: steps.check-avai.outputs.avai == 'true'
run: |
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
CUDA_EXT=1 pip install -v -e .
BUILD_EXT=1 pip install -v -e .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
pip install -r requirements/requirements-test.txt

View File

@ -45,9 +45,9 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
timeout-minutes: 15
steps:
- name: 📚 Checkout
uses: actions/checkout@v3

View File

@ -77,7 +77,7 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 20
concurrency:

View File

@ -34,7 +34,7 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
timeout-minutes: 10
steps:
- name: 📚 Checkout

View File

@ -18,7 +18,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb
timeout-minutes: 30
defaults:

View File

@ -20,7 +20,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt
timeout-minutes: 30
defaults:

View File

@ -19,7 +19,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
volumes:
- /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa
- /data/scratch/llama-tiny:/data/scratch/llama-tiny

View File

@ -1,4 +1,4 @@
include *.txt README.md
recursive-include requirements *.txt
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
recursive-include op_builder *.py
recursive-include extensions *.py *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi

View File

@ -142,7 +142,7 @@ distributed training and inference in a few lines.
[[Modelscope model weights]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-13b-base/summary)
| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot)| AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) |
| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: |
| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: |
| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
@ -160,6 +160,7 @@ distributed training and inference in a few lines.
| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - |
| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - |
| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 |
| **Colossal-LLaMA-2-13b-base** | Llama-2-13B | **0.025T** | 56.42 | 61.80 | 54.69 | 69.53 | 60.3 |
### ColossalChat

View File

@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from colossalai.utils import get_current_device
from colossalai.accelerator import get_accelerator
from .base import OnPolicyTrainer
from .callbacks import Callback
@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer):
self.critic_optim = critic_optim
self.offload_inference_models = offload_inference_models
self.device = get_current_device()
self.device = get_accelerator().get_current_device()
def _before_fit(
self,

View File

@ -6,7 +6,6 @@ import torch.nn as nn
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.utils import get_current_device
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy
@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy):
warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
# colossalai has changed api for get_current_device in 0.3.4 version or newer
try:
from colossalai.accelerator import get_accelerator
chunk_init_device = get_accelerator().get_current_device()
except:
from colossalai.utils import get_current_device
chunk_init_device = get_current_device()
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
chunk_init_device=get_current_device(),
chunk_init_device=chunk_init_device,
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,

View File

@ -6,12 +6,12 @@ Initialize new tokenizer for continual pre-training
"""
import argparse
import os
import json
import os
from typing import List, Union
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from colossalai.logging import get_dist_logger

View File

@ -16,7 +16,10 @@ import torch
def unwrap(model):
return model.unwrap().module
if hasattr(model, "module"):
return unwrap_model(model.module)
else:
return model
def neftune_post_forward_hook(module, input, output):

View File

@ -4,41 +4,34 @@
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
"""
import json
import argparse
import json
import os
import resource
from contextlib import nullcontext
from tqdm import tqdm
import torch
import torch.distributed as dist
from colossal_llama2.dataset.loader import (
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_distributed_dataloader,
)
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
from torch.utils.tensorboard import SummaryWriter
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import (
GeminiPlugin,
LowLevelZeroPlugin,
HybridParallelPlugin,
)
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossal_llama2.dataset.loader import (
load_tokenized_dataset,
setup_distributed_dataloader,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
)
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
def get_model_numel(model: torch.nn.Module) -> int:
@ -215,9 +208,18 @@ def main() -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
init_ctx = (
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
)
# colossalai has changed api for get_current_device in 0.3.4 version or newer
try:
from colossalai.accelerator import get_accelerator
current_device = get_accelerator().get_current_device()
except:
from colossalai.utils import get_current_device
current_device = get_current_device()
init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
with init_ctx:
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
# Freeze part of parameters.
@ -320,7 +322,7 @@ def main() -> None:
initial=start_step,
) as pbar:
for step, batch in pbar:
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch_output = model(**batch)
@ -372,9 +374,7 @@ def main() -> None:
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

View File

@ -136,6 +136,19 @@ class ColossalLLM(LLM):
"""Get the identifying parameters."""
return {"n": self.n}
def get_token_ids(self, text: str) -> List[int]:
"""Return the ordered ids of the tokens in a text.
Args:
text: The string input to tokenize.
Returns:
A list of ids corresponding to the tokens in the text, in order they occur
in the text.
"""
# use the colossal llm's tokenizer instead of langchain's cached GPT2 tokenizer
return self.api.tokenizer.encode(text)
class VllmLLM(LLM):
"""

View File

@ -1,4 +1,5 @@
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
from . import accelerator
try:
# .version will be created by setup.py

View File

@ -0,0 +1,20 @@
# 🚀 Accelerator
## 🔗 Table of Contents
- [🚀 Accelerator](#-accelerator)
- [🔗 Table of Contents](#-table-of-contents)
- [📚 Introduction](#-introduction)
- [📌 Design and Acknowledgement](#-design-and-acknowledgement)
## 📚 Introduction
This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platform with a simple `auto_set_accelerator()` API.
## 📌 Design and Acknowledgement
Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work.
We implemented this accelerator module from scratch. At the same time, we have implemented our own modifications:
1. we updated the accelerator API names to be aligned with PyTorch's native API names.
2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled.

View File

@ -0,0 +1,15 @@
from .api import auto_set_accelerator, get_accelerator, set_accelerator
from .base_accelerator import BaseAccelerator
from .cpu_accelerator import CpuAccelerator
from .cuda_accelerator import CudaAccelerator
from .npu_accelerator import NpuAccelerator
__all__ = [
"get_accelerator",
"set_accelerator",
"auto_set_accelerator",
"BaseAccelerator",
"CudaAccelerator",
"NpuAccelerator",
"CpuAccelerator",
]

View File

@ -0,0 +1,71 @@
#!/usr/bin/env python
from collections import OrderedDict
from typing import Union
from .base_accelerator import BaseAccelerator
from .cpu_accelerator import CpuAccelerator
from .cuda_accelerator import CudaAccelerator
from .npu_accelerator import NpuAccelerator
__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"]
_ACCELERATOR = None
# we use ordered dictionary here to associate the
# order with device check priority
# i.e. auto_set_accelerator will check cuda first
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)
def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
"""
Set the global accelerator for the current process.
Args:
accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs.
"""
global _ACCELERATOR
if isinstance(accelerator, str):
_ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]()
elif isinstance(accelerator, BaseAccelerator):
_ACCELERATOR = accelerator
else:
raise TypeError("accelerator must be either a string or an instance of BaseAccelerator")
def auto_set_accelerator() -> None:
"""
Automatically check if any accelerator is available.
If an accelerator is availabe, set it as the global accelerator.
"""
global _ACCELERATOR
for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
try:
accelerator = accelerator_cls()
if accelerator_name == "cpu" or accelerator.is_available():
_ACCELERATOR = accelerator
break
except:
pass
if _ACCELERATOR is None:
raise RuntimeError("No accelerator is available.")
def get_accelerator() -> BaseAccelerator:
"""
Return the accelerator for the current process. If the accelerator is not initialized, it will be initialized
to the default accelerator type.
Returns: the accelerator for the current process.
"""
global _ACCELERATOR
if _ACCELERATOR is None:
auto_set_accelerator()
return _ACCELERATOR

View File

@ -0,0 +1,320 @@
#!/usr/bin/env python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
__all__ = ["BaseAccelerator"]
class BaseAccelerator(ABC):
support_set_device: bool = True
def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
self._name = name
self._communication_backend = communication_backend
self._is_synchronous = is_synchronous
# =======================
# immutable attributes
# =======================
@property
def name(self) -> str:
"""
Return the name of the accelerator.
"""
return self._name
@property
def communication_backend(self) -> str:
"""
Return the name of the backend communication library.
"""
return self._communication_backend
@property
def is_synchronous(self) -> bool:
"""
Return whether the accelerator is a synchronous device.
"""
return self._is_synchronous
def __repr__(self) -> str:
cls_name = self.__class__.__name__
return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})"
# =======================
# device APIs
# =======================
@abstractmethod
def get_version(self) -> str:
"""
Return the version of the accelerator which torch is built against.
"""
@abstractmethod
def get_current_device(self) -> torch.device:
"""
Return the current device.
"""
@abstractmethod
def current_device(self) -> int:
"""
Return the current device index.
"""
@abstractmethod
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
"""
Bind the current process to a device.
"""
@abstractmethod
def get_device_name(self, device: Union[torch.device, int]) -> str:
"""
Return the name of the device.
"""
@abstractmethod
def synchronize(self, device: Union[torch.device, int] = None):
"""
Synchronize the current process.
"""
@abstractmethod
def is_available(self):
"""
Check if the accelerator is available.
"""
@abstractmethod
def device_count(self):
"""
Return the number of devices on the machine.
"""
def set_to_device(self, models: Any) -> Any:
"""
Send model to device.
:param models: nn.module or a list of module
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(self.get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(self.get_current_device())
else:
return models.to(self.get_current_device())
@abstractmethod
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the capability of a device.
"""
@abstractmethod
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
@abstractmethod
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
@abstractmethod
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc.
"""
# =======================
# random number generator APIs
# =======================
@abstractmethod
def get_rng_state(self, device="cuda") -> torch.Tensor:
"""
Returns the random number generator state of the specified device as a ByteTensor.
"""
@abstractmethod
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
@abstractmethod
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
"""
Sets the random number generator state of the specified device.
"""
@abstractmethod
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
@abstractmethod
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current device.
"""
@abstractmethod
def manual_seed_all(self, seed: int) -> None:
"""
Sets the seed for generating random numbers on all devices.
"""
@abstractmethod
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current device.
"""
@abstractmethod
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all devices.
"""
@abstractmethod
def initial_seed(self) -> int:
"""
Returns the current random seed of the current device.
"""
# =======================
# memory management APIs
# =======================
@abstractmethod
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi.
"""
@abstractmethod
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of CUDA memory allocator statistics for a given device.
"""
@abstractmethod
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
@abstractmethod
def memory_snapshot(self):
"""
Returns a snapshot of the CUDA memory allocator state across all devices.
"""
@abstractmethod
def memory_allocated(self, device=None) -> int:
"""
Returns the current device memory occupied by tensors in bytes for a given device.
"""
@abstractmethod
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum device memory occupied by tensors in bytes for a given device.
"""
@abstractmethod
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum device memory occupied by tensors for a given device.
"""
@abstractmethod
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device.
"""
@abstractmethod
def memory_reserved(self, device=None) -> int:
"""
Returns the current device memory managed by the caching allocator in bytes for a given device.
"""
@abstractmethod
def max_memory_reserved(self, device=None) -> int:
"""
Returns the maximum device memory managed by the caching allocator in bytes for a given device.
"""
@abstractmethod
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
@abstractmethod
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the device memory allocator.
"""
# =======================
# streams and events APIs
# =======================
@abstractmethod
def Stream(self, device=None, priority=0, **kwargs):
"""
A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
"""
@abstractmethod
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
"""
@abstractmethod
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
@abstractmethod
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
@abstractmethod
def set_stream(self, stream_):
"""
Sets the current stream.This is a wrapper API to set the stream.
"""
@abstractmethod
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
# =======================
# amp APIs
# =======================
@abstractmethod
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""

View File

@ -0,0 +1,283 @@
#!/usr/bin/env python
import resource
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import psutil
import torch
from .base_accelerator import BaseAccelerator
__all__ = ["CpuAccelerator"]
class CpuAccelerator(BaseAccelerator):
support_set_device: bool = False
"""
Accelerator class for cpu.
"""
def __init__(self):
super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False)
# =======================
# device APIs
# =======================
def get_version(self) -> str:
"""
Return the version of the accelerator which torch is built against.
"""
return ""
def get_current_device(self) -> torch.device:
"""
Return the current device.
"""
return torch.device("cpu")
def current_device(self) -> int:
"""
Return the current device index.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
"""
Bind the current process to a device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def get_device_name(self, device: Union[torch.device, int]) -> str:
"""
Return the name of the device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def synchronize(self, device: Union[torch.device, int] = None):
"""
Synchronize the current process.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def is_available(self):
"""
Check if the accelerator is available.
"""
return True
def device_count(self):
"""
Return the number of devices on the machine.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the cuda capability of a device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
"""
raise RuntimeError("this method is not supported for cpu accelerator")
# =======================
# random number generator APIs
# =======================
def get_rng_state(self, device=None) -> torch.Tensor:
"""
Returns the random number generator state of the specified GPU as a ByteTensor.
"""
return torch.get_rng_state(device)
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None:
"""
Sets the random number generator state of the specified GPU.
"""
torch.set_rng_state(new_state)
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current GPU.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def manual_seed_all(self, seed: int) -> None:
"""
Set the random seed for the all processes.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current GPU.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all GPUs.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def initial_seed(self) -> int:
"""
Returns the current random seed of the current GPU.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
# =======================
# memory management APIs
# =======================
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of CUDA memory allocator statistics for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_snapshot(self):
"""
Returns a snapshot of the CUDA memory allocator state across all devices.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_allocated(self, device=None) -> int:
"""
Returns the current GPU memory occupied by tensors in bytes for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_reserved(self, device=None) -> int:
"""
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
"""
return psutil.Process().memory_info().rss
def max_memory_reserved(self, device=None) -> int:
"""
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
"""
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
max_memory = int(psutil.virtual_memory().total * fraction)
_, hard = resource.getrlimit(resource.RLIMIT_AS)
resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard))
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the CUDA memory allocator.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
# =======================
# streams and events APIs
# =======================
def Stream(self, device=None, priority=0, **kwargs):
"""
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def set_stream(self, stream_):
"""
Sets the current stream.This is a wrapper API to set the stream.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
# =======================
# amp APIs
# =======================
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""
return nullcontext

View File

@ -0,0 +1,282 @@
#!/usr/bin/env python
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from .base_accelerator import BaseAccelerator
__all__ = ["CudaAccelerator"]
class CudaAccelerator(BaseAccelerator):
"""
Accelerator class for Nvidia CUDA devices.
"""
def __init__(self):
super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False)
# =======================
# device APIs
# =======================
def get_version(self) -> str:
"""
Return the version of the accelerator which torch is built against.
"""
return torch.version.cuda
def get_current_device(self) -> torch.device:
"""
Return the current device.
"""
return torch.device(f"cuda:{torch.cuda.current_device()}")
def current_device(self) -> int:
"""
Return the current device index.
"""
return torch.cuda.current_device()
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
"""
Bind the current process to a device.
"""
if device is None:
if not dist.is_initialized():
raise RuntimeError("Cannot get current device when distributed is not initialized.")
device = dist.get_rank() % self.device_count()
torch.cuda.set_device(device)
def get_device_name(self, device: Union[torch.device, int]) -> str:
"""
Return the name of the device.
"""
return torch.cuda.get_device_name(device)
def synchronize(self, device: Union[torch.device, int] = None):
"""
Synchronize the current process.
"""
torch.cuda.synchronize(device)
def is_available(self):
"""
Check if the accelerator is available.
"""
return torch.cuda.is_available()
def device_count(self):
"""
Return the number of devices on the machine.
"""
return torch.cuda.device_count()
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the cuda capability of a device.
"""
return torch.cuda.get_device_capability(device)
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
return torch.cuda.get_device_name(device)
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
return torch.cuda.get_device_properties(device)
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
"""
return torch.cuda.utilization(device)
# =======================
# random number generator APIs
# =======================
def get_rng_state(self, device="cuda") -> torch.Tensor:
"""
Returns the random number generator state of the specified GPU as a ByteTensor.
"""
return torch.cuda.get_rng_state(device)
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
return torch.cuda.get_rng_state_all()
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
"""
Sets the random number generator state of the specified GPU.
"""
torch.cuda.set_rng_state(new_state, device)
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
torch.cuda.set_rng_state_all(new_states)
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current GPU.
"""
torch.cuda.manual_seed(seed)
def manual_seed_all(self, seed: int) -> None:
"""
Set the random seed for the all processes.
"""
torch.cuda.manual_seed_all(seed)
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current GPU.
"""
torch.cuda.seed()
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all GPUs.
"""
torch.cuda.seed_all()
def initial_seed(self) -> int:
"""
Returns the current random seed of the current GPU.
"""
return torch.cuda.initial_seed()
# =======================
# memory management APIs
# =======================
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
"""
torch.cuda.empty_cache()
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of CUDA memory allocator statistics for a given device.
"""
return torch.cuda.memory_stats(device=device)
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
return torch.cuda.memory_summary(device=device, abbreviated=abbreviated)
def memory_snapshot(self):
"""
Returns a snapshot of the CUDA memory allocator state across all devices.
"""
return torch.cuda.memory_snapshot()
def memory_allocated(self, device=None) -> int:
"""
Returns the current GPU memory occupied by tensors in bytes for a given device.
"""
return torch.cuda.memory_allocated(device=device)
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
"""
return torch.cuda.max_memory_allocated(device=device)
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
"""
torch.cuda.reset_max_memory_allocated(device=device)
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
"""
torch.cuda.reset_max_memory_cached(device=device)
def memory_reserved(self, device=None) -> int:
"""
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.cuda.memory_reserved(device=device)
def max_memory_reserved(self, device=None) -> int:
"""
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.cuda.max_memory_reserved(device=device)
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
torch.cuda.set_per_process_memory_fraction(fraction, device=device)
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the CUDA memory allocator.
"""
torch.cuda.reset_peak_memory_stats(device=device)
# =======================
# streams and events APIs
# =======================
def Stream(self, device=None, priority=0, **kwargs):
"""
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
"""
return torch.cuda.Stream(device, priority, **kwargs)
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
"""
return torch.cuda.Event(enable_timing, blocking, interprocess)
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
return torch.cuda.current_stream(device)
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
return torch.cuda.default_stream(device)
def set_stream(self, stream_):
"""
Sets the current stream.This is a wrapper API to set the stream.
"""
torch.cuda.set_stream(stream_)
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
return torch.cuda.stream(stream_)
# =======================
# amp APIs
# =======================
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)

View File

@ -0,0 +1,288 @@
#!/usr/bin/env python
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from .base_accelerator import BaseAccelerator
try:
import torch_npu # noqa
except ImportError:
pass
__all__ = ["NpuAccelerator"]
class NpuAccelerator(BaseAccelerator):
"""
Accelerator class for Huawei NPU devices.
"""
def __init__(self):
super().__init__(name="npu", communication_backend="hccl", is_synchronous=False)
# =======================
# device APIs
# =======================
def get_version(self) -> str:
"""
Return the version of the accelerator which torch is built against.
"""
return torch.version.cann
def get_current_device(self) -> torch.device:
"""
Return the current device.
"""
return torch.device(f"npu:{torch.npu.current_device()}")
def current_device(self) -> int:
"""
Return the current device index.
"""
return torch.npu.current_device()
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
"""
Bind the current process to a device.
"""
if device is None:
if not dist.is_initialized():
raise RuntimeError("Cannot get current device when distributed is not initialized.")
device = dist.get_rank() % self.device_count()
torch.npu.set_device(device)
def get_device_name(self, device: Union[torch.device, int]) -> str:
"""
Return the name of the device.
"""
return torch.npu.get_device_name(device)
def synchronize(self, device: Union[torch.device, int] = None):
"""
Synchronize the current process.
"""
torch.npu.synchronize(device)
def is_available(self):
"""
Check if the accelerator is available.
"""
return torch.npu.is_available()
def device_count(self):
"""
Return the number of devices on the machine.
"""
return torch.npu.device_count()
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the npu capability of a device.
"""
return torch.npu.get_device_capability(device)
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
return torch.npu.get_device_name(device)
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
return torch.npu.get_device_properties(device)
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
"""
return torch.npu.utilization(device)
# =======================
# random number generator APIs
# =======================
def get_rng_state(self, device="npu") -> torch.Tensor:
"""
Returns the random number generator state of the specified GPU as a ByteTensor.
"""
return torch.npu.get_rng_state(device)
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
return torch.npu.get_rng_state_all()
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None:
"""
Sets the random number generator state of the specified GPU.
"""
torch.npu.set_rng_state(new_state, device)
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
torch.npu.set_rng_state_all(new_states)
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current GPU.
"""
torch.npu.manual_seed(seed)
def manual_seed_all(self, seed: int) -> None:
"""
Set the random seed for the all processes.
"""
torch.npu.manual_seed_all(seed)
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current GPU.
"""
torch.npu.seed()
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all GPUs.
"""
torch.npu.seed_all()
def initial_seed(self) -> int:
"""
Returns the current random seed of the current GPU.
"""
return torch.npu.initial_seed()
# =======================
# memory management APIs
# =======================
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
"""
torch.npu.empty_cache()
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of npu memory allocator statistics for a given device.
"""
return torch.npu.memory_stats(device=device)
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
return torch.npu.memory_summary(device=device, abbreviated=abbreviated)
def memory_snapshot(self):
"""
Returns a snapshot of the npu memory allocator state across all devices.
"""
return torch.npu.memory_snapshot()
def memory_allocated(self, device=None) -> int:
"""
Returns the current GPU memory occupied by tensors in bytes for a given device.
"""
return torch.npu.memory_allocated(device=device)
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
"""
return torch.npu.max_memory_allocated(device=device)
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
"""
torch.npu.reset_max_memory_allocated(device=device)
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
"""
torch.npu.reset_max_memory_cached(device=device)
def memory_reserved(self, device=None) -> int:
"""
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.npu.memory_reserved(device=device)
def max_memory_reserved(self, device=None) -> int:
"""
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.npu.max_memory_reserved(device=device)
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
torch.npu.set_per_process_memory_fraction(fraction, device=device)
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the npu memory allocator.
"""
torch.npu.reset_peak_memory_stats(device=device)
# =======================
# streams and events APIs
# =======================
def Stream(self, device=None, priority=0, **kwargs):
"""
A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details.
"""
return torch.npu.Stream(device, priority, **kwargs)
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams.
"""
return torch.npu.Event(enable_timing, blocking, interprocess)
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
return torch.npu.current_stream(device)
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
return torch.npu.default_stream(device)
def set_stream(self, stream_):
"""
Sets the current stream.This is a wrapper API to set the stream.
"""
torch.npu.set_stream(stream_)
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
return torch.npu.stream(stream_)
# =======================
# amp APIs
# =======================
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""
return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)

View File

@ -7,8 +7,8 @@ from typing import Dict
import torch
from torch import Tensor
from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
from colossalai.utils.device import get_current_device
__all__ = ["BaseGradScaler"]
@ -23,7 +23,7 @@ class BaseGradScaler(ABC):
def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0
self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
self._verbose = verbose
if self._verbose:

View File

@ -5,7 +5,7 @@ from typing import Optional
import torch
from colossalai.utils.device import get_current_device
from colossalai.accelerator import get_accelerator
from .base_grad_scaler import BaseGradScaler
@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler):
hysteresis: int = 2,
verbose: bool = False,
):
a = get_accelerator()
a.device_count()
super().__init__(initial_scale, verbose)
if min_scale:
self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
self._min_scale = torch.tensor(
[min_scale], device=get_accelerator().get_current_device(), dtype=torch.float
)
else:
self._min_scale = None
if max_scale:
self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
self._max_scale = torch.tensor(
[max_scale], device=get_accelerator().get_current_device(), dtype=torch.float
)
else:
self._max_scale = None
@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler):
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict["scale"].to(get_current_device())
self._scale = state_dict["scale"].to(get_accelerator().get_current_device())
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"]

View File

@ -5,8 +5,8 @@ import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base import MixedPrecisionMixin
@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
max_scale=max_scale,
)
self.optim_state = OptimState.UNSCALED
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
@property
def loss_scale(self) -> float:

View File

@ -4,10 +4,10 @@ from typing import Dict, Tuple
import torch
from torch.optim import Optimizer
from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
from .region import Region
@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper):
hysteresis=hysteresis,
max_scale=max_scale,
)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._found_overflow: torch.Tensor = torch.zeros(
1, dtype=torch.int64, device=get_accelerator().get_current_device()
)
self._logger = get_dist_logger()
def _set_grad_ptr(self):

View File

@ -11,7 +11,7 @@ except:
import torch
from torch.fx.node import Node
from colossalai.utils.device import get_current_device
from colossalai.accelerator import get_accelerator
from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
@ -57,7 +57,10 @@ class Solver(ABC):
if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor
else:
self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
self.memory_budget = (
torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
* self.error_factor
)
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power()

View File

@ -5,8 +5,8 @@ import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils.device import autocast
from .mixed_precision_base import MixedPrecision
@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
super().__init__(module)
def forward(self, *args, **kwargs):
with autocast():
with get_accelerator().autocast():
return self.module(*args, **kwargs)

View File

@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
@ -27,8 +28,6 @@ from colossalai.checkpoint_io.utils import (
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
from colossalai.utils.device import IS_NPU_AVAILABLE
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@ -366,11 +365,11 @@ class GeminiPlugin(DPPluginBase):
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if IS_NPU_AVAILABLE:
if get_accelerator().name == "npu":
assert placement_policy == "static", "NPU only supports static placement policy"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
placement_policy=placement_policy,
enable_gradient_accumulation=enable_gradient_accumulation,
shard_param_frac=shard_param_frac,
@ -486,7 +485,10 @@ class GeminiPlugin(DPPluginBase):
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
sampler = DistributedSampler(
dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle
dataset,
num_replicas=zero_world_size * extra_dp_world_size,
rank=zero_rank * extra_dp_world_size + extra_dp_rank,
shuffle=shuffle,
)
# Deterministic dataloader

View File

@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
@ -28,7 +29,6 @@ from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.utils.device import get_current_device
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
module = module.to(get_current_device())
module = module.to(get_accelerator().get_current_device())
# setting input type cast when using mixed precision
self.convert_fn = None
@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
Returns:
None
"""
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
if grads is not None:
# Synchronize provided gradient tensors across the tensor parallelism group.
@ -346,7 +345,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if self.pp_size > 1:
@ -386,7 +387,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
# compute norm in tp process group
@ -487,7 +488,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
@ -513,7 +513,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
@ -545,7 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
# so we need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@ -589,7 +590,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
# compute norm in tp process group
@ -674,7 +675,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
Returns:
None
"""
# Call the superclass `_sync_grad` method to synchronize gradients.
super()._sync_grad()
@ -802,7 +802,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# so we only need to calculate the norm 'tp' of 'pp' gradients.
total_norm = super()._compute_grad_norm(gradients, norm_type)
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@ -842,7 +844,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if dp_size > 1:
# compute norm in dp process group
@ -1081,7 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase):
return True
def support_no_sync(self) -> bool:
return False
return True
def control_checkpoint_io(self) -> bool:
return True
@ -1175,9 +1177,14 @@ class HybridParallelPlugin(PipelinePluginBase):
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs
# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()
@ -1241,5 +1248,8 @@ class HybridParallelPlugin(PipelinePluginBase):
def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (
self.zero_stage != 2
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

View File

@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase
@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:

View File

@ -6,12 +6,12 @@ import warnings
from pathlib import Path
from typing import Dict, Union
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.context import Config
from colossalai.logging import get_dist_logger
from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
from colossalai.utils import set_seed
def launch(
@ -47,17 +47,18 @@ def launch(
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")
if IS_NPU_AVAILABLE and backend == "nccl":
backend = "hccl"
cur_accelerator = get_accelerator()
backend = cur_accelerator.communication_backend
# init default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
if torch.cuda.is_available() or IS_NPU_AVAILABLE:
# if local rank is not given, calculate automatically
set_device(local_rank)
if cur_accelerator.support_set_device:
cur_accelerator.set_device(local_rank)
set_seed(seed)

View File

@ -1,7 +0,0 @@
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
__all__ = [
"LayerNorm",
"FusedScaleMaskSoftmax",
"MultiHeadAttention",
]

View File

@ -1,63 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "column_remap.cuh"
#include "util.cuh"
const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;
__global__ void column_remap_kernel
(
const half* __restrict__ x,
half* __restrict__ x_new,
const int x_width,
const int x_height,
const uint32_t* x_map
)
{
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
if (x_column >= x_width) return;
//if (x_row >= x_height) return;
int x_stride = x_width;
int x_idx = x_row * x_stride + x_column;
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
int x_idx_end = x_row_end * x_stride + x_column;
int s_column = x_map[x_column];
int s_idx = x_row * x_stride + s_column;
while (x_idx < x_idx_end)
{
x_new[x_idx] = x[s_idx];
x_idx += x_stride;
s_idx += x_stride;
}
}
// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
)
{
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
1
);
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}

View File

@ -1,19 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _column_remap_cuh
#define _column_remap_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
);
#endif

View File

@ -1,58 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif

View File

@ -1,75 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#define _cuda_buffers_cu
#include "cuda_buffers.cuh"
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;
CudaBuffers::CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
) :
device(_device),
temp_state_size(_temp_state_size),
temp_state(_temp_state),
temp_dq(_temp_dq)
{
cudaSetDevice(_device);
cudaStreamCreate(&alt_stream_1);
cudaStreamCreate(&alt_stream_2);
cudaStreamCreate(&alt_stream_3);
cudaEventCreate(&alt_stream_1_done);
cudaEventCreate(&alt_stream_2_done);
cudaEventCreate(&alt_stream_3_done);
}
CudaBuffers::~CudaBuffers()
{
cudaStreamDestroy(alt_stream_1);
cudaStreamDestroy(alt_stream_2);
cudaStreamDestroy(alt_stream_3);
cudaEventDestroy(alt_stream_1_done);
cudaEventDestroy(alt_stream_2_done);
cudaEventDestroy(alt_stream_3_done);
}
CudaBuffers* get_buffers(const int device_index)
{
return g_buffers[device_index];
}
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
)
{
CudaBuffers* buffers = new CudaBuffers
(
_device,
_temp_state_size,
_temp_state,
_temp_dq
);
g_buffers[_device] = buffers;
}
void cleanup_buffers_cuda()
{
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
{
if (!g_buffers[i]) continue;
delete g_buffers[i];
g_buffers[i] = NULL;
}
}

View File

@ -1,55 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
const int CUDA_MAX_DEVICES = 16;
// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif
class CudaBuffers
{
public:
int device;
half* temp_state; // [max_hidden_rows * intermediate_size]
int temp_state_size;
half* temp_dq; // size of largest quant tensor * 8
cudaStream_t alt_stream_1;
cudaStream_t alt_stream_2;
cudaStream_t alt_stream_3;
cudaEvent_t alt_stream_1_done;
cudaEvent_t alt_stream_2_done;
cudaEvent_t alt_stream_3_done;
CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
~CudaBuffers();
};
CudaBuffers* get_buffers(const int device_index);
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
void cleanup_buffers_cuda();
#endif

View File

@ -1,49 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif

View File

@ -1,254 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "q4_matrix.cuh"
#include "q4_matmul.cuh"
#include "column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
// buffer size used for sanity checks
temp_state.numel(),
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
}

View File

@ -1,294 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _matrix_cuh
#define _matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
__device__ __forceinline__ half2 dot_product_8
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
half2 tmp = __hmul2(*h_ptr++, v_01);
tmp = __hfma2(*h_ptr++, v_23, tmp);
tmp = __hfma2(*h_ptr++, v_45, tmp);
tmp = __hfma2(*h_ptr++, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half* h_ptr = h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(*h_ptr++, v_0);
tmp = __hfma(*h_ptr++, v_1, tmp);
tmp = __hfma(*h_ptr++, v_2, tmp);
tmp = __hfma(*h_ptr++, v_3, tmp);
tmp = __hfma(*h_ptr++, v_4, tmp);
tmp = __hfma(*h_ptr++, v_5, tmp);
tmp = __hfma(*h_ptr++, v_6, tmp);
tmp = __hfma(*h_ptr++, v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];
half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);
half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
#endif

View File

@ -1,260 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "util.cuh"
#include "matrix.cuh"
#include "cu_compat.cuh"
#include "cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
typedef void (*fp_q4_matmul_kernel)
(
const half*,
const uint32_t*,
half*,
const half*,
const uint32_t*,
const int,
const int,
const int,
const int,
const int,
const uint32_t*,
bool
);
template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
const half* __restrict__ x,
const uint32_t* __restrict__ w,
half* __restrict__ out,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int dim,
const int width,
const int groupsize,
const int block_size_z,
const uint32_t* __restrict__ x_map,
bool no_zero
)
{
// Start of block
int x_column = block_size_z * blockIdx.z;
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
int iterations = (x_column_end - x_column) / 8;
// Views
MatrixView_half x_(x, height, dim);
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
MatrixView_q4_column w_(w, dim, width);
MatrixView_half_rw out_(out, height, width);
// Zero output
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
{
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
__syncthreads();
}
// Loop over part of x row (and w column)
half2 acc = {};
half acc_h = {};
if constexpr (use_groupsize)
{
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
// could be slightly faster
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
{
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
}
}
else
{
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
for (int k = x_column; k < x_column + iterations * 8; k += 8)
{
if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
}
}
// Add to block result
if constexpr (use_half2)
{
half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else
{
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
}
}
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
// <bool use_half2, bool use_groupsize, bool use_x_map>
if (tuningParams->matmul_no_half2) {
if (block_size_z % groupsize == 0) {
if (x_map) return q4_matmul_kernel<false, true, true >;
else return q4_matmul_kernel<false, true, false>;
} else {
if (x_map) return q4_matmul_kernel<false, false, true >;
else return q4_matmul_kernel<false, false, false>;
}
} else {
if (block_size_z % groupsize == 0)
{
if (x_map) return q4_matmul_kernel<true, true, true >;
else return q4_matmul_kernel<true, true, false>;
} else {
if (x_map) return q4_matmul_kernel<true, false, true >;
else return q4_matmul_kernel<true, false, false>;
}
}
};
// Compute y = x @ w
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero,
cudaStream_t alt_stream
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
uint32_t* x_map = w->cuda_x_map;
const half* x_mapped = x;
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
{
CudaBuffers* buffers = get_buffers(w->device);
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
x_map = NULL;
}
int block_size_z;
if (w->width == 4096) block_size_z = 384; // 7B
else if (w->width == 11008) block_size_z = 256;
else if (w->width == 5120) block_size_z = 384; // 13B
else if (w->width == 13824) block_size_z = 256;
else if (w->width == 6656) block_size_z = 256; // 33B
else if (w->width == 17920) block_size_z = 128;
else block_size_z = 256;
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
dim3 threads(THREADS_X, THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height + threads.y - 1) / threads.y,
(dim + block_size_z - 1) / block_size_z
);
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
CudaBuffers* buffers = get_buffers(w->device);
const half* x_mapped = x;
if (w->cuda_x_map)
{
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}
w->reconstruct(buffers->temp_dq);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
const float alpha = 1.0f;
const float beta = no_zero ? 1.0f : 0.0f;
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
const half alpha = __float2half(1.0f);
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}

View File

@ -1,43 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include "tuning.h"
// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero = false,
cudaStream_t alt_stream = NULL
);
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero = false
);
#endif

View File

@ -1,225 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matrix.cuh"
#include <vector>
#include "util.cuh"
#include "matrix.cuh"
using namespace std;
const int UNSHUF_BLOCKSIZE_X = 64;
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
vector<Q4Matrix*> g_q4_matrices;
void g_q4_keep_matrix(Q4Matrix* m)
{
g_q4_matrices.push_back(m);
}
void g_q4_free_matrices()
{
for (const auto& m : g_q4_matrices) delete m;
g_q4_matrices.clear();
}
Q4Matrix::Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
) :
height(_height),
width(_width),
groups(_groups),
device(_device)
{
cudaSetDevice(device);
cuda_qweight = _qweight;
cuda_qzeros = _qzeros;
cuda_scales = _scales;
groupsize = height / groups;
if (_g_idx) make_sequential(_g_idx);
}
Q4Matrix::~Q4Matrix()
{
}
// Make sequential
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint32_t* __restrict__ x_map,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int x_map_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = x_map[x_map_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Move to CUDA
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
height / 8,
1
);
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
// Replace qweights
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
}
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ w,
half* __restrict__ out, // (y)
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int width,
const int groupsize
)
{
// Start of block
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, height / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
// Groupsize version
int group = row / groupsize;
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
{
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
void Q4Matrix::reconstruct(half* out)
{
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height / 8 + threads.y - 1) / threads.y,
1
);
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}

View File

@ -1,53 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
class Q4Matrix
{
public:
int device;
int height;
int width;
int groups;
int groupsize;
uint32_t* cuda_qweight = NULL;
uint32_t* cuda_qzeros = NULL;
half* cuda_scales = NULL;
uint32_t* cuda_x_map = NULL;
Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
);
~Q4Matrix();
void reconstruct(half* out);
private:
void make_sequential(const uint32_t* cpu_g_idx);
};
void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();
#endif

View File

@ -1,12 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _tuning_h
#define _tuning_h
struct ExLlamaTuning {
int matmul_recons_thd;
bool matmul_fused_remap;
bool matmul_no_half2;
};
#endif

View File

@ -1,33 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif
// React to failure on return code != cudaSuccess
#define _cuda_check(fn) \
do { \
{_cuda_err = fn;} \
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)
// React to failure on return code == 0
#define _alloc_check(fn) \
do { \
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
else _cuda_err = cudaSuccess; \
} while(false)
#endif

View File

@ -1,191 +0,0 @@
#include "block_reduce.h"
#include "cuda_util.h"
#include "kernels.h"
#include "ls_cub.cuh"
ls::cub::CachingDeviceAllocator g_allocator(true);
template <typename T>
__global__ void ls_cross_entropy_fw_kernel(
const T *__restrict__ inputs, const int *__restrict__ targets,
float *__restrict__ outputs, float *__restrict__ nll_loss_outputs,
const int padding_idx, const float epsilon, const int vocab_size) {
/* step1: compute each thread's max_logit and sum_exp_logit, store in
* max_input, sum_exp_logit */
const int block_start = blockIdx.x * vocab_size;
const int left_idx = block_start + threadIdx.x;
const int right_idx = (blockIdx.x + 1) * vocab_size;
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
int target_tid = targets[blockIdx.x];
if (target_tid == padding_idx) {
if (threadIdx.x == 0) {
nll_loss_outputs[blockIdx.x] = 0.f;
outputs[blockIdx.x] = 0.f;
}
return;
}
for (int i = left_idx; i < right_idx; i += blockDim.x) {
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
}
blockReduce<ReduceType::kMax, 1>(max_input);
__shared__ float s_max_input;
if (threadIdx.x == 0) {
s_max_input = max_input[0];
}
__syncthreads();
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float logit = static_cast<float>(inputs[i]) - s_max_input;
sum_logits[0] += logit;
sum_logits[1] += expf(logit);
}
blockReduce<ReduceType::kSum, 2>(sum_logits);
__shared__ float s_sum_logit;
__shared__ float s_sum_exp;
if (threadIdx.x == 0) {
s_sum_logit = sum_logits[0];
s_sum_exp = sum_logits[1];
}
__syncthreads();
float eps_i = epsilon / (vocab_size - 1);
if (threadIdx.x == 0) {
// neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max)
float nll_loss = logf(s_sum_exp) -
static_cast<float>(inputs[block_start + target_tid]) +
s_max_input;
nll_loss_outputs[blockIdx.x] = nll_loss;
float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit;
outputs[blockIdx.x] =
(1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
}
}
template <typename T>
__global__ void ls_cross_entropy_bw_kernel(
const float *__restrict__ grad_outputs, const T *__restrict__ inputs,
const int *__restrict__ targets, T *__restrict__ grad_inputs,
const int padding_idx, const float epsilon, const int vocab_size) {
/* step1: compute each thread's max_logit and sum_exp_logit, store in
* max_input, sum_exp_logit */
const int block_start = blockIdx.x * vocab_size;
const int left_idx = block_start + threadIdx.x;
const int right_idx = (blockIdx.x + 1) * vocab_size;
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
float sum_logits[1] = {0.f};
const float grad_out = static_cast<float>(grad_outputs[0]);
int target_tid = targets[blockIdx.x];
if (target_tid == padding_idx) {
for (int i = left_idx; i < right_idx; i += blockDim.x) {
grad_inputs[i] = 0.f;
}
return;
}
for (int i = left_idx; i < right_idx; i += blockDim.x) {
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
}
blockReduce<ReduceType::kMax, 1>(max_input);
__shared__ float s_max_input;
if (threadIdx.x == 0) {
s_max_input = max_input[0];
}
__syncthreads();
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float logit = static_cast<float>(inputs[i]) - s_max_input;
sum_logits[0] += expf(logit);
}
blockReduce<ReduceType::kSum, 1>(sum_logits);
__shared__ float s_sum_exp;
if (threadIdx.x == 0) {
s_sum_exp = sum_logits[0];
}
__syncthreads();
float eps_i = epsilon / (vocab_size - 1);
float nll_weight = 1.0 - epsilon - eps_i;
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float prob = expf(static_cast<float>(inputs[i]) - s_max_input) / s_sum_exp;
float grad = 0;
grad += (vocab_size * prob - 1) * eps_i;
grad += prob * nll_weight;
if ((i - block_start) == target_tid) {
grad -= nll_weight;
}
grad_inputs[i] = grad_out * grad;
}
}
template <typename T>
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
float *outputs_ptr, float *nll_loss_ptr,
float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size,
const int seq_len, const int vocab_size,
cudaStream_t stream) {
int grid_dim = batch_size * seq_len;
float *nll_loss_buffer = loss_buffer + grid_dim;
ls_cross_entropy_fw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
epsilon, vocab_size);
int num_items = grid_dim;
void *d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
loss_buffer, outputs_ptr,
num_items, stream));
CHECK_GPU_ERROR(
g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
loss_buffer, outputs_ptr,
num_items, stream));
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
nll_loss_buffer, nll_loss_ptr,
num_items, stream));
CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
}
template void launch_cross_entropy_fw<float>(
const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template void launch_cross_entropy_fw<__half>(
const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template <typename T>
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
const int *targets_ptr, T *grad_inputs_ptr,
const int padding_idx, const float epsilon,
const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream) {
int grid_dim = batch_size * seq_len;
ls_cross_entropy_bw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
epsilon, vocab_size);
}
template void launch_cross_entropy_bw<float>(
const float *grad_outputs_ptr, const float *inputs_ptr,
const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template void launch_cross_entropy_bw<__half>(
const float *grad_outputs_ptr, const __half *inputs_ptr,
const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);

View File

@ -1,88 +0,0 @@
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#include "cublas_wrappers.h"
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const float *A,
const float *B, float *C, cublasGemmAlgo_t algo) {
cublasStatus_t status =
cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha,
(const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k,
(const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n,
(const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const __half *A,
const __half *B, __half *C, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmEx(
handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F,
(transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C,
CUDA_R_16F, m, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const float *A, const float *B, float *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmStridedBatchedEx(
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F,
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F,
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C,
batch, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
"error: %d) \n",
batch, m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const __half *A, const __half *B, __half *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmStridedBatchedEx(
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F,
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F,
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C,
batch, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}

View File

@ -1,169 +0,0 @@
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/transform_reduce.h>
#include "cuda_util.h"
/* GPU function guard */
std::string _cudaGetErrorString(cudaError_t error) {
return cudaGetErrorString(error);
}
std::string _cudaGetErrorString(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "CUBLAS_UNKNOW";
}
template <typename T>
void check_gpu_error(T result, char const *const func, const char *const file,
int const line) {
if (result) {
throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" +
std::to_string(line) +
"): " + (_cudaGetErrorString(result)) + "\n");
}
}
template void check_gpu_error<cudaError_t>(cudaError_t result,
char const *const func,
const char *const file,
int const line);
template void check_gpu_error<cublasStatus_t>(cublasStatus_t result,
char const *const func,
const char *const file,
int const line);
template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele) {
std::cout << outn << ": ";
std::vector<T> hout(num_output_ele, (T)0);
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T),
cudaMemcpyDeviceToHost);
for (int i = 0; i < num_output_ele; i++) {
std::cout << hout[i] << ", ";
}
std::cout << std::endl;
}
template <>
void print_vec<__half>(const __half *outv, std::string outn,
int num_output_ele) {
std::cout << outn << ": ";
std::vector<__half> hout(num_output_ele, (__half)0.f);
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half),
cudaMemcpyDeviceToHost);
for (int i = 0; i < num_output_ele; i++) {
std::cout << __half2float(hout[i]) << ", ";
}
std::cout << std::endl;
}
template void print_vec<float>(const float *outv, std::string outn,
int num_output_ele);
template void print_vec<int>(const int *outv, std::string outn,
int num_output_ele);
template void print_vec<__half>(const __half *outv, std::string outn,
int num_output_ele);
template <typename T>
T *cuda_malloc(size_t ele_num) {
size_t byte_size = ele_num * sizeof(T);
T *pdata = nullptr;
CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size));
return pdata;
}
template float *cuda_malloc<float>(size_t ele_num);
template __half *cuda_malloc<__half>(size_t ele_num);
template uint8_t *cuda_malloc<uint8_t>(size_t ele_num);
void cuda_free(void *pdata) {
if (pdata != nullptr) {
cudaFree(pdata);
}
}
template <typename T>
struct _isnan {
__device__ bool operator()(T a) const { return isnan(a); }
};
template <>
struct _isnan<__half> {
__device__ bool operator()(const __half a) const { return __hisnan(a); }
};
template <typename T>
struct _isinf {
__device__ bool operator()(T a) const { return isinf(a); }
};
template <>
struct _isinf<__half> {
__device__ bool operator()(const __half a) const { return __hisinf(a); }
};
template <typename T>
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
std::string file, int line, cudaStream_t stream) {
// check_nan_inf = 0 for checking nan
// check_nan_inf = 1 for checking inf
bool res = false;
std::string msg = file + "(" + std::to_string(line) + "): ";
if (check_nan_inf) {
msg += "nan.";
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
data_ptr + dsize, _isnan<T>(), false,
thrust::logical_or<bool>());
} else {
msg += "inf.";
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
data_ptr + dsize, _isinf<T>(), false,
thrust::logical_or<bool>());
}
if (res) {
throw std::runtime_error(msg);
}
std::cout << msg << " [check pass]." << std::endl;
}
template void check_nan_inf<float>(const float *data_ptr, int dsize,
bool check_nan_inf, std::string file,
int line, cudaStream_t stream);
template void check_nan_inf<__half>(const __half *data_ptr, int dsize,
bool check_nan_inf, std::string file,
int line, cudaStream_t stream);

File diff suppressed because it is too large Load Diff

View File

@ -1,232 +0,0 @@
#include <cooperative_groups.h>
#include "kernels.h"
namespace cg = cooperative_groups;
/**
@brief: fuse_transpose_bias
Calculate the sum of elements in each column of the matrix.
@thread
gridDim.x = ceil(cols / WARP_SIZE)
blockDim.x = WARP_SIZE
blockDim.y = WARP_SIZE
@param
inp: [rows, cols]
out: [cols]
rows: the number of rows in the matrix
cols: the number of cols in the matrix
*/
template <typename T>
__global__ void column_sum_reduce(const T *__restrict__ inp,
T *__restrict__ out, int rows, int cols) {
__shared__ float tile[WARP_SIZE][WARP_SIZE];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
int y_stride = cols * WARP_SIZE;
float localSum = 0;
// Loop across matrix row
// TODO: optimize to log complexity
if (idx < cols) {
int offset = flat_2dim(threadIdx.y, idx, cols);
for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
// The sum of a row in tile is equal to the sum of a col in original matrix
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
// The change of threadIdx.x is continuous
float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads();
// Calculate the sum of a row in tile
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
if (pos < cols) out[pos] = sum;
}
}
// [r, c] -> [c]
template <>
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
int rows, int cols,
cudaStream_t stream) {
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE);
column_sum_reduce<float>
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
int rows, int cols,
cudaStream_t stream) {
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE);
column_sum_reduce<__half>
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
/**
@brief: fused_add2
Add two matrix inp1 and inp2 to out.
@thread
gridDim.x = batch_size * seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
inp1: [batch_size, seq_len, hidden_dim]
inp2: [batch_size, seq_len, hidden_dim]
out: [batch_size, seq_len, hidden_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
*/
template <typename T>
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
int hidden_dim);
template <>
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
const float *inp2, int hidden_dim) {
int row_id = blockIdx.x;
int offset = flat_2dim(row_id, 0, hidden_dim);
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
float4 *out_4 = reinterpret_cast<float4 *>(out);
float4 vinp1;
float4 vinp2;
float4 val;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinp1 = inp1_4[offset + i];
vinp2 = inp2_4[offset + i];
val.x = vinp1.x + vinp2.x;
val.y = vinp1.y + vinp2.y;
val.z = vinp1.z + vinp2.z;
val.w = vinp1.w + vinp2.w;
out_4[offset + i] = val;
}
}
template <>
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
const __half *inp2, int hidden_dim) {
int row_id = blockIdx.x;
int offset = flat_2dim(row_id, 0, hidden_dim);
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
float4 *out_4 = reinterpret_cast<float4 *>(out);
float4 vinp1;
float4 vinp2;
float4 val;
__half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1);
__half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2);
__half2 *h2_val = reinterpret_cast<__half2 *>(&val);
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinp1 = inp1_4[offset + i];
vinp2 = inp2_4[offset + i];
h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]);
h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]);
h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]);
h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]);
out_4[offset + i] = val;
}
}
//[b, s, h] -> [b, s, h]
template <>
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
int batch_size, int seq_len, int hidden_dim,
cudaStream_t &stream) {
hidden_dim >>= 2;
dim3 grid_dim(batch_size * seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
hidden_dim);
}
template <>
void launch_fused_add2<__half>(__half *out, const __half *inp1,
const __half *inp2, int batch_size, int seq_len,
int hidden_dim, cudaStream_t &stream) {
hidden_dim >>= 3;
dim3 grid_dim(batch_size * seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
hidden_dim);
}
template <typename T>
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
int sz0, int sz2, int sz1_1, int sz1_2) {
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
if (idx >= nele) {
return;
}
float4 *dst_ptr = (float4 *)output + idx;
int idx2 = idx % sz2;
idx = idx / sz2;
int idx1 = idx % (sz1_1 + sz1_2);
int idx0 = idx / (sz1_1 + sz1_2);
float4 *src_ptr = nullptr;
int sz1 = 0;
if (idx1 < sz1_1) {
sz1 = sz1_1;
src_ptr = (float4 *)inp1;
} else {
idx1 -= sz1_1;
sz1 = sz1_2;
src_ptr = (float4 *)inp2;
}
src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2);
dst_ptr[0] = src_ptr[0];
}
template <>
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
float *output, int sz0, int sz2, int sz1_1,
int sz1_2, cudaStream_t stream) {
sz2 >>= 2;
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}
template <>
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
__half *output, int sz0, int sz2, int sz1_1,
int sz1_2, cudaStream_t stream) {
sz2 >>= 3;
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}

View File

@ -1,36 +0,0 @@
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <iostream>
#include <string>
#include "cuda_util.h"
class Context {
public:
Context() : _stream(nullptr) {
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
}
virtual ~Context() {}
static Context &Instance() {
static Context _ctx;
return _ctx;
}
void set_stream(cudaStream_t stream) {
_stream = stream;
CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream));
}
cudaStream_t get_stream() { return _stream; }
cublasHandle_t get_cublashandle() { return _cublasHandle; }
private:
cudaStream_t _stream;
cublasHandle_t _cublasHandle;
};

View File

@ -1,46 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <type_traits>
#include "cuda_util.h"
template <typename T>
class CrossEntropyLayer {
public:
CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
virtual ~CrossEntropyLayer();
void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
float *nll_loss_ptr);
void Backward(const float *grad_outputs_ptr, const T *inputs_ptr,
const int *targets_ptr, T *grad_inputs_ptr);
void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
private:
void allocate_mem_buffer() {
// allocate local gpu memory
_loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
}
void free_mem_buffer() {
// free local gpu memory
cuda_free(_loss_buffer);
}
const int _padding_idx;
const float _epsilon;
const int _max_batch_tokens;
size_t _batch_size;
size_t _seq_len;
size_t _vocab_size;
float *_loss_buffer;
};

View File

@ -1,41 +0,0 @@
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#pragma once
#include <assert.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <stdio.h>
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const float *A,
const float *B, float *C,
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const __half *A,
const __half *B, __half *C,
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const float *A, const float *B, float *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch,
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
int cublas_strided_batched_gemm(
cublasHandle_t handle, int m, int n, int k, const float *alpha,
const float *beta, const __half *A, const __half *B, __half *C,
cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B,
int stride_C, int batch,
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);

View File

@ -1,34 +0,0 @@
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <math_constants.h>
#include <chrono>
#include <fstream>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>
template <typename T>
void check_gpu_error(T result, char const *const func, const char *const file,
int const line);
#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__)
template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele);
template <typename T>
T *cuda_malloc(size_t ele_num);
void cuda_free(void *pdata);
template <typename T>
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
std::string file, int line, cudaStream_t stream);
#define CHECK_NAN_INF(ptr, size, stream) \
check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))

View File

@ -1,96 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <string>
#include "kernels.h"
template <typename T>
class Dropout {
public:
struct Config {
float ratio;
bool training;
Config(float r) : ratio(r), training(true) {}
float RATIO() const { return training ? ratio : 0.0; }
};
Dropout(const Config &config, size_t max_ele_num)
: _config(config), _mask(nullptr) {
_mask = cuda_malloc<uint8_t>(max_ele_num);
}
virtual ~Dropout() { cuda_free(_mask); }
// after attention softmax
void dropout(T *output, const T *input, int count, cudaStream_t stream,
bool bwd = false) {
launch_ls_dropout<T>(output, input, _mask, count, _config.RATIO(), stream,
bwd);
}
void d_dropout(T *d_inp_out, int count, cudaStream_t stream) {
launch_ls_dropout<T>(d_inp_out, d_inp_out, _mask, count, _config.RATIO(),
stream, true);
}
// transformer layer's postprocessing dropout, after attn or ffn module,
// before residual add.
void bias_dropout_residual(T *output, const T *input, const T *residual,
const T *bias, int rows, int cols,
cudaStream_t stream) {
launch_ls_dropout_res_bias<T>(output, input, _mask, bias, residual,
rows * cols, cols, _config.RATIO(), stream);
}
void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output,
int rows, int cols, cudaStream_t stream) {
launch_ls_dropout_bias_bwd<T>(d_input, d_bias, d_output, _mask, rows, cols,
_config.RATIO(), stream);
}
// dropout inside ffn.
void bias_act_dropout(T *output, const T *input, const T *bias, int rows,
int cols, std::string activation_fn,
cudaStream_t stream) {
if (activation_fn == "relu") {
launch_ls_dropout_act_bias<ActivationType::kRelu, T>(
output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
stream);
} else if (activation_fn == "gelu") {
launch_ls_dropout_act_bias<ActivationType::kGelu, T>(
output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
stream);
} else {
throw std::runtime_error("not supported activation: " + activation_fn);
}
}
void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input,
const T *bias, int rows, int cols,
std::string activation_fn, cudaStream_t stream) {
if (activation_fn == "relu") {
launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, T>(
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
_config.RATIO(), stream);
} else if (activation_fn == "gelu") {
launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, T>(
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
_config.RATIO(), stream);
} else {
throw std::runtime_error("not supported activation: " + activation_fn);
}
}
bool HasDropout() const { return _config.RATIO() > 0.0; }
void SetTrainingMode(bool training) { _config.training = training; }
private:
uint8_t *_mask;
Config _config;
};

View File

@ -1,69 +0,0 @@
#pragma once
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <array>
#include "cublas_wrappers.h"
#include "kernels.h"
template <typename T>
class FeedForward {
public:
struct Config {
int outputSize;
int inputSize;
std::array<int, 3> gemm_algos;
Config(int outputs, int inputs)
: outputSize(outputs),
inputSize(inputs),
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
};
FeedForward(Config config) : config_(config) {}
~FeedForward() {}
void Forward(int bsz, const T *input_ptr, const T *weights, T *out,
cublasHandle_t &_cublasHandle) {
float alpha = T(1.);
float beta = T(0.);
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize,
bsz, config_.inputSize, &alpha, &beta, weights, input_ptr,
out, cublasGemmAlgo_t(config_.gemm_algos[0]));
}
void Backward(int bsz, const T *out_grad, const T *input_ptr,
const T *weights, T *weights_grad, T *bias_grad,
cublasHandle_t &_cublasHandle, cudaStream_t &stream,
T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr,
bool compute_bias = true) {
float alpha = (T)1.0, beta = (T)0.0;
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize,
config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad,
weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1]));
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize,
bsz, config_.outputSize, &alpha, &beta, weights, out_grad,
inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2]));
if (compute_bias) {
launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz,
config_.outputSize, stream);
}
}
void reset_size(int outputSize, int inputSize) {
config_.outputSize = outputSize;
config_.inputSize = inputSize;
}
private:
Config config_;
};

View File

@ -1,275 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdexcept>
#define MAX_THREADS 1024
#define WARP_SIZE 32
enum class ActivationType { kRelu, kGelu };
void launch_curand_init(int total_count, int dim, cudaStream_t stream);
template <typename T>
void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp,
const T *scale, const T *bias, int batch_size,
int hidden_dim, cudaStream_t stream);
template <typename T>
void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
const T *residual_grad, const T *inp_or_out, const T *gamma,
const T *betta, const T *vars, const T *means, int batch,
int hidden_dim, cudaStream_t stream[2]);
template <typename T>
void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads,
int from_len, int to_len, bool mask_future,
cudaStream_t stream);
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
int softmax_len, cudaStream_t stream);
// [b, s, h] -> [b, nh, s, ad]
template <typename T>
void launch_transform_0213(T *output, const T *vals, int batch_size,
int seq_length, int hidden_dim, int nhead,
cudaStream_t stream);
// [b, s, 3, h] -> [3, b, nh, s, ad]
template <typename T>
void launch_bias_add_transform_20314(T *output, const T *input, const T *bias,
int dim_0, int dim_1, int dim_2, int dim_3,
int dim_4, cudaStream_t stream);
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
template <typename T>
void launch_transform4d_0213(T *output, const T *vals, int batch_size,
int seq_len, int hidden_dim, int nhead,
int trans_count, cudaStream_t stream);
template <typename T>
void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count,
float ratio, cudaStream_t stream, bool backward = false);
template <typename T>
void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask,
const T *bias, const T *residual,
int total_count, int dim, float ratio,
cudaStream_t stream);
template <ActivationType, typename T>
void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask,
const T *bias, int total_count, int dim,
float ratio, cudaStream_t stream);
template <typename T>
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template <ActivationType act_type, typename T>
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
const T *bias, const T *out_grad,
const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template <typename T>
void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols,
cudaStream_t stream);
void launch_param_update(const float *input, __half *output, int size,
cudaStream_t stream);
template <typename T>
void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0,
int sz2, int sz1_1, int sz1_2, cudaStream_t stream);
template <typename T>
void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size,
int seq_len, int hidden_size, cudaStream_t &stream);
template <typename T>
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
float *outputs_ptr, float *nll_loss_ptr,
float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size,
const int seq_len, const int vocab_size,
cudaStream_t stream);
template <typename T>
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
const int *targets_ptr, T *grad_inputs_ptr,
const int padding_idx, const float epsilon,
const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template <typename T>
void launch_lookup_scale_pos_dropout(
T *output, const int *input, const T *embeddings, const T *pos_embeddings,
uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim,
int padding_idx, float dropout_ratio, int step, cudaStream_t &stream);
template <typename T>
void launch_d_lookup_scale_pos_dropout(
T *grad_embeddings, const T *grad_output, const int *input,
const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim,
int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream);
/* Convert 2-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) {
return id1 * dim2 + id2;
}
/* Convert 3-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
int dim2, int dim3) {
return id1 * dim2 * dim3 + id2 * dim3 + id3;
}
/* Convert 4-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3,
int id4, int dim2, int dim3,
int dim4) {
// return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
int res = id4;
int ld = dim4;
res += id3 * ld;
ld *= dim3;
res += id2 * ld;
ld *= dim2;
res += id1 * ld;
return res;
}
/* Convert 5-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3,
int id4, int id5, int dim2,
int dim3, int dim4,
int dim5) {
// return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) +
// id4*dim5 + dim5;
int res = id5;
int ld = dim5;
res += id4 * ld;
ld *= dim4;
res += id3 * ld;
ld *= dim3;
res += id2 * ld;
ld *= dim2;
res += id1 * ld;
return res;
}
/* Convert 6-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
int id4, int id5, int id6,
int dim2, int dim3, int dim4,
int dim5, int dim6) {
// return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) +
// id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6;
int res = id6;
int ld = dim6;
res += id5 * ld;
ld *= dim5;
res += id4 * ld;
ld *= dim4;
res += id3 * ld;
ld *= dim3;
res += id2 * ld;
ld *= dim2;
res += id1 * ld;
return res;
}
/* Convert vector index to 6-dim tensor index */
__forceinline__ __host__ __device__ void decompose_6dim(
int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0,
int *id1, int *id2, int *id3, int *id4, int *id5) {
*id5 = src % dim5;
src /= dim5;
*id4 = src % dim4;
src /= dim4;
*id3 = src % dim3;
src /= dim3;
*id2 = src % dim2;
src /= dim2;
*id1 = src % dim1;
*id0 = src / dim1;
}
/* Convert vector index to 5-dim tensor index */
__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1,
int dim2, int dim3,
int dim4, int *id0,
int *id1, int *id2,
int *id3, int *id4) {
*id4 = src % dim4;
src /= dim4;
*id3 = src % dim3;
src /= dim3;
*id2 = src % dim2;
src /= dim2;
*id1 = src % dim1;
*id0 = src / dim1;
}
/* Convert vector index to 4-dim tensor index */
__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
int dim2, int dim3,
int *id0, int *id1,
int *id2, int *id3) {
*id3 = src % dim3;
src /= dim3;
*id2 = src % dim2;
src /= dim2;
*id1 = src % dim1;
*id0 = src / dim1;
}
/* Convert vector index to 3-dim tensor index */
__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1,
int dim2, int *id0,
int *id1, int *id2) {
*id2 = src % dim2;
src /= dim2;
*id1 = src % dim1;
*id0 = src / dim1;
}
/* Convert vector index to 2-dim tensor index */
__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1,
int *id0, int *id1) {
*id1 = src % dim1;
*id0 = src / dim1;
}

View File

@ -1,12 +0,0 @@
// copied from https://github.com/dmlc/dgl/pull/2758
#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
#define DGL_ARRAY_CUDA_DGL_CUB_CUH_
#define CUB_NS_PREFIX namespace ls {
#define CUB_NS_POSTFIX }
#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
#endif

View File

@ -1,65 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "kernels.h"
using namespace std;
template <typename T>
class Normalize_Layer {
public:
struct Config {
uint32_t hidden_dim;
bool use_mean;
Config(uint32_t hidden_dim, bool use_mean = false)
: hidden_dim(hidden_dim), use_mean(use_mean) {}
};
Normalize_Layer(Config config, size_t max_rows)
: config_(config), vars_(nullptr), means_(nullptr) {
vars_ = cuda_malloc<T>(max_rows);
if (config_.use_mean) {
means_ = cuda_malloc<T>(max_rows);
}
}
~Normalize_Layer() {
cuda_free(vars_);
cuda_free(means_);
}
void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta,
int batch_size, cudaStream_t stream) {
launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size,
config_.hidden_dim, stream);
}
/*
residual_grad, inp_or_out, betta should be treated carefully.
inp_or_out = input if use_mean else output
residual_grad, betta can be nullptr.
residual_grad will be added to dinp if it is not nullptr
which is useful in transformer layer when pre-ln
betta are only used to compute xhat,
(use_mean == false) ^ (betta == nullptr) should be true
*/
void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
const T *residual_grad, const T *inp_or_out, const T *gamma,
const T *betta, int batch_size, cudaStream_t stream[2]) {
launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad,
inp_or_out, gamma, betta, vars_, means_, batch_size,
config_.hidden_dim, stream);
}
inline bool use_mean() const { return config_.use_mean; }
private:
Config config_;
T *vars_;
T *means_;
};

View File

@ -1,42 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "kernels.h"
using namespace std;
template <typename T>
class Softmax {
public:
struct Config {
size_t nhead;
Config(size_t nhead) : nhead(nhead) {}
};
Softmax(Config config) : config_(config) {}
~Softmax() {}
void Forward(T *vals, const T *attn_mask, int batch_size, int from_len,
int to_len, cudaStream_t &stream, bool mask_future = true) {
launch_attn_softmax<T>(vals, attn_mask, batch_size, config_.nhead, from_len,
to_len, mask_future, stream);
}
void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len,
int to_len, cudaStream_t stream) {
launch_attn_softmax_bw<T>(out_grad, soft_out,
batch_size * config_.nhead * from_len, to_len,
stream);
}
void reset_size(size_t nhead) { config_.nhead = nhead; }
private:
Config config_;
};

View File

@ -1,100 +0,0 @@
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <array>
#include "cublas_wrappers.h"
template <typename T>
class StridedBatchGemm {
public:
struct Config {
int m;
int n;
int k;
float alpha;
float beta;
cublasOperation_t op_A;
cublasOperation_t op_B;
std::array<int, 3> gemm_algos;
Config(float param_alpha, float param_beta, cublasOperation_t opA,
cublasOperation_t opB)
: alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
void SetConfig(int mm, int nn, int kk) {
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config &config) : _config(config) {}
virtual ~StridedBatchGemm() {}
void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
cublasHandle_t handle) {
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(
handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
_buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
}
void Backward(int bsz, const T *d_output, const T *_buffer_a,
const T *_buffer_b, cublasHandle_t handle,
T *inpGradA = nullptr, T *inpGradB = nullptr) {
int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B need to transpose.
cublasOperation_t op_b =
(_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
// Calculate d_A.
cublas_strided_batched_gemm(
handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
(_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
(_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
cublasGemmAlgo_t(_config.gemm_algos[1]));
// A need to transpose.
cublasOperation_t op_a =
(_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(
handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
_buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
}
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
};

File diff suppressed because it is too large Load Diff

View File

@ -1,365 +0,0 @@
#include <cooperative_groups.h>
#include <math.h>
#include <cub/block/block_load.cuh>
#include <cub/cub.cuh>
#include "block_reduce.h"
#include "kernels.h"
namespace cg = cooperative_groups;
const float EPSILON = 1e-8f;
/**
@brief: softmax_kernel
Softmax forward kernel for
enc-self-attn, dec-self-attn, encdec-attn
@thread
gridDim.x = dynamic
gridDim.y = batch_size
gridDim.z = nhead
blockDim.x = from_len
@param
inp: [batch_size, nhead, from_len, to_len], softmax input.
attn_mask: [batch_size, to_len], padding tokens are -inf,
non padding tokens are 0.
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=ture for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn infer
*/
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
int to_len, bool mask_future) {
int batch_id = blockIdx.y;
int head_id = blockIdx.z;
const int nhead = gridDim.z;
const int token_per_reduce = 1;
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_dim, ele_per_thread,
cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
T mval[ele_per_thread];
if (attn_mask) {
attn_mask += batch_id * to_len;
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
}
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
token_id += gridDim.x * token_per_reduce) {
T inp_val[token_per_reduce][ele_per_thread];
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
REDUCE_FLOAT_INF_NEG);
}
/* step 1. compute max */
// thread local max
float val[token_per_reduce][ele_per_thread];
float l_max[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_max[i] = REDUCE_FLOAT_INF_NEG;
for (int j = 0; j < ele_per_thread; j++) {
if (attn_mask) {
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
} else {
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
val[i][j] = REDUCE_FLOAT_INF_NEG;
} else {
val[i][j] = (float)inp_val[i][j];
}
}
l_max[i] = fmaxf(l_max[i], val[i][j]);
}
}
// block reduce max
blockReduce<ReduceType::kMax, token_per_reduce>(l_max);
// write shared
__shared__ float s_max[token_per_reduce];
if (threadIdx.x == 0) {
for (int i = 0; i < token_per_reduce; i++) {
s_max[i] = l_max[i];
}
}
__syncthreads();
/* step 2. compute sum */
// thread local sum
float l_sum[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_sum[i] = 0.f;
for (int j = 0; j < ele_per_thread; j++) {
val[i][j] = __expf(val[i][j] - s_max[i]);
l_sum[i] += val[i][j];
}
}
// block reduce sum
blockReduce<ReduceType::kSum, token_per_reduce>(l_sum);
// write shared
__shared__ float s_sum[token_per_reduce];
if (threadIdx.x == 0) {
for (int i = 0; i < token_per_reduce; i++) {
s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
}
}
__syncthreads();
/* step 3. compute final result */
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
for (int j = 0; j < ele_per_thread; j++) {
inp_val[i][j] = (T)(val[i][j] * s_sum[i]);
}
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len);
}
} // blockIdx.x
}
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
int to_len, bool mask_future) {
int batch_id = blockIdx.y;
int head_id = blockIdx.z;
const int nhead = gridDim.z;
const int token_per_reduce = 1;
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_dim, ele_per_thread,
cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
T mval[ele_per_thread];
if (attn_mask) {
attn_mask += batch_id * to_len;
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
}
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
token_id += gridDim.x * token_per_reduce) {
T inp_val[token_per_reduce][ele_per_thread];
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
REDUCE_FLOAT_INF_NEG);
}
/* step 1. compute max */
// thread local max
float val[token_per_reduce][ele_per_thread];
float l_max[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_max[i] = REDUCE_FLOAT_INF_NEG;
for (int j = 0; j < ele_per_thread; j++) {
if (attn_mask) {
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
} else {
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
val[i][j] = REDUCE_FLOAT_INF_NEG;
} else {
val[i][j] = (float)inp_val[i][j];
}
}
l_max[i] = fmaxf(l_max[i], val[i][j]);
}
}
// warp reduce max
warpReduce<ReduceType::kMax, token_per_reduce>(l_max);
/* step 2. compute sum */
// thread local sum
float l_sum[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_sum[i] = 0.f;
for (int j = 0; j < ele_per_thread; j++) {
val[i][j] = __expf(val[i][j] - l_max[i]);
l_sum[i] += val[i][j];
}
}
// warp reduce sum
warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);
/* step 3. compute final result */
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
for (int j = 0; j < ele_per_thread; j++) {
inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
}
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len);
}
} // blockIdx.x
}
/*
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=ture for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn infer
*/
template <>
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
int batch_size, int nhead, int from_len,
int to_len, bool mask_future,
cudaStream_t stream) {
dim3 grid_dim(1, batch_size, nhead);
if (to_len <= 32) {
ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 64) {
ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 128) {
grid_dim.x = 16;
ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 256) {
grid_dim.x = 32;
ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 512) {
grid_dim.x = 64;
ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else {
throw std::runtime_error(
"Sequence length greater than 512 is currently not supported");
}
}
template <>
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
int batch_size, int nhead, int from_len,
int to_len, bool mask_future,
cudaStream_t stream) {
dim3 grid_dim(1, batch_size, nhead);
if (to_len <= 32) {
ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 64) {
ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 128) {
grid_dim.x = 8;
ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 256) {
grid_dim.x = 16;
ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 512) {
grid_dim.x = 32;
ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else {
throw std::runtime_error(
"Sequence length greater than 512 is currently not supported");
}
}
/**
@brief: ker_attn_softmax_bw
Softmax backward in self attention.
@thread
gridDim.x = batch_size * nhead * seq_len / warps_per_block
blockDim.x = WARP_SIZE
blockDim.y = warps_per_block
@param
grad: [batch_size, nhead, seq_len, seq_len], output grad.
output: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
*/
template <typename T, int ITERATIONS>
__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
inp += offset;
T grad_reg[ITERATIONS];
T inp_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
inp_reg[i] = inp[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)inp_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
}
}
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
int softmax_len, cudaStream_t stream) {
const int warps_per_block = 4;
// rows = batch_size * nhead * from_len
dim3 grid_dim(rows / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
if (softmax_len <= 32)
ker_attn_softmax_bw<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 64)
ker_attn_softmax_bw<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 128)
ker_attn_softmax_bw<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 256)
ker_attn_softmax_bw<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 384)
ker_attn_softmax_bw<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 512)
ker_attn_softmax_bw<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 768)
ker_attn_softmax_bw<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 1024)
ker_attn_softmax_bw<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 2048)
ker_attn_softmax_bw<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else
throw std::runtime_error(
std::string(
"Special sequence length found in softmax backward, seq_len: ") +
std::to_string(softmax_len));
}
template void launch_attn_softmax_bw<__half>(__half *out_grad,
const __half *soft_inp, int rows,
int softmax_len,
cudaStream_t stream);
template void launch_attn_softmax_bw<float>(float *out_grad,
const float *soft_inp, int rows,
int softmax_len,
cudaStream_t stream);

View File

@ -1,314 +0,0 @@
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include "kernels.h"
using namespace cub;
/**
@brief: transform_0213
Split the attention heads and reshape input
during backward progress of encoder self-attention
@thread
gridDim.x = batch_size
gridDim.y = seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
input: [batch_size, seq_len, hidden_dim]
output: [batch_size, nhead, seq_len, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
*/
template <typename T>
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
int head_dim);
template <>
__global__ void transform_0213<float>(float *output, const float *input,
int hidden_dim, int head_dim) {
int batch_id = blockIdx.x;
int token_id = blockIdx.y;
int seq_len = gridDim.y;
int nhead = hidden_dim / head_dim;
// [b, s, h]
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
// [b, nh, s, ad]
int trg_offset =
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vinput4;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinput4 = input4[src_offset + i];
int head_id = i / head_dim;
int dim_id = i % head_dim;
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
res4[trg_offset + cur_trg_offset] = vinput4;
}
}
template <>
__global__ void transform_0213<__half>(__half *output, const __half *input,
int hidden_dim, int head_dim) {
int batch_id = blockIdx.x;
int token_id = blockIdx.y;
int seq_len = gridDim.y;
int nhead = hidden_dim / head_dim;
// [b, s, h]
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
// [b, nh, s, ad]
int trg_offset =
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vinput4;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinput4 = input4[src_offset + i];
int head_id = i / head_dim;
int dim_id = i % head_dim;
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
res4[trg_offset + cur_trg_offset] = vinput4;
}
}
// [b, s, h] -> [b, nh, s, ad]
template <>
void launch_transform_0213<float>(float *output, const float *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, cudaStream_t stream) {
hidden_dim >>= 2;
int head_dim = hidden_dim / nhead;
dim3 grid_dim(batch_size, seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
transform_0213<float>
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
}
template <>
void launch_transform_0213<__half>(__half *output, const __half *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, cudaStream_t stream) {
hidden_dim >>= 3;
int head_dim = hidden_dim / nhead;
dim3 grid_dim(batch_size, seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
transform_0213<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
}
/**
@brief: bias_add_transform_20314
Add bias to input, transform from
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
@thread
gridDim.x = dim_0
gridDim.y = dim_1
gridDim.z = dim_2
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
@param
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
bias: [dim_2, dim_3, dim_4]
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
*/
template <typename T>
__global__ void bias_add_transform_20314(T *output, const T *input,
const T *bias, int dim_3, int dim_4);
template <>
__global__ void bias_add_transform_20314<float>(float *output,
const float *input,
const float *bias, int dim_3,
int dim_4) {
int id0 = blockIdx.x;
int id1 = blockIdx.y;
int id2 = blockIdx.z;
int dim_0 = gridDim.x;
int dim_1 = gridDim.y;
int dim_2 = gridDim.z;
int dim_34 = dim_3 * dim_4;
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
int bias_offset = flat_2dim(id2, 0, dim_34);
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vqkv4;
float4 vbias4;
float4 vres4;
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
vqkv4 = qkv4[src_offset + i];
vbias4 = bias4[bias_offset + i];
vres4.x = vqkv4.x + vbias4.x;
vres4.y = vqkv4.y + vbias4.y;
vres4.z = vqkv4.z + vbias4.z;
vres4.w = vqkv4.w + vbias4.w;
int id3 = i / dim_4;
int id4 = i % dim_4;
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
res4[trg_offset + cur_trg_offset] = vres4;
}
}
template <>
__global__ void bias_add_transform_20314<__half>(__half *output,
const __half *input,
const __half *bias, int dim_3,
int dim_4) {
int id0 = blockIdx.x;
int id1 = blockIdx.y;
int id2 = blockIdx.z;
int dim_0 = gridDim.x;
int dim_1 = gridDim.y;
int dim_2 = gridDim.z;
int dim_34 = dim_3 * dim_4;
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
int bias_offset = flat_2dim(id2, 0, dim_34);
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vqkv4;
float4 vbias4;
float4 vres4;
__half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4);
__half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4);
__half2 *h2_res = reinterpret_cast<__half2 *>(&vres4);
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
vqkv4 = qkv4[src_offset + i];
vbias4 = bias4[bias_offset + i];
h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]);
h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]);
h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]);
h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]);
int id3 = i / dim_4;
int id4 = i % dim_4;
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
res4[trg_offset + cur_trg_offset] = vres4;
}
}
// [b, s, 3, h] -> [3, b, nh, s, ad]
template <>
void launch_bias_add_transform_20314<float>(float *output, const float *input,
const float *bias, int dim_0,
int dim_1, int dim_2, int dim_3,
int dim_4, cudaStream_t stream) {
dim_4 >>= 2;
dim3 grid_dim(dim_0, dim_1, dim_2);
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
bias_add_transform_20314<float>
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
}
template <>
void launch_bias_add_transform_20314<__half>(__half *output,
const __half *input,
const __half *bias, int dim_0,
int dim_1, int dim_2, int dim_3,
int dim_4, cudaStream_t stream) {
dim_4 >>= 3;
dim3 grid_dim(dim_0, dim_1, dim_2);
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
bias_add_transform_20314<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
}
/**
@brief: transform4d_0213
Reshape the input matrix to merge the heads
@thread
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
blockDim.x = max_block_thread
@param
input: [trans_count, batch_size, nhead, seq_len, head_dim]
output: [batch_size, seq_len, trans_count, nhead, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
trans_count: 1 or 3, the count of matrice need to be transformed
*/
template <typename T>
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
int seq_len, int trans_count, int nhead,
int head_dim, int num_all) {
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset >= num_all) {
return;
}
int trans_id, batch_id, head_id, token_id, dim_id;
decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
&batch_id, &head_id, &token_id, &dim_id);
// [b, s, tc, nh, ad]
int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
seq_len, trans_count, nhead, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
res4[trg_offset] = input4[offset];
}
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
template <>
void launch_transform4d_0213<float>(float *output, const float *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, int trans_count,
cudaStream_t stream) {
hidden_dim >>= 2;
int head_dim = hidden_dim / nhead;
int num_all = batch_size * seq_len * trans_count * hidden_dim;
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
num_all);
}
template <>
void launch_transform4d_0213<__half>(__half *output, const __half *input,
int batch_size, int seq_len,
int hidden_dim, int nhead, int trans_count,
cudaStream_t stream) {
hidden_dim >>= 3;
int head_dim = hidden_dim / nhead;
int num_all = batch_size * seq_len * trans_count * hidden_dim;
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
num_all);
}

View File

@ -1,406 +0,0 @@
#include "multihead_attention_1d.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/Types.hpp>
#else
#include <c10d/Types.hpp>
#endif
#include <iostream>
#include "context.h"
#include "kernels.h"
template <typename T>
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
int max_seq_len, int hidden_size,
int num_heads,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
bool pre_or_postLayerNorm)
: _layer_id(layer_id),
_max_batch_tokens(max_batch_tokens),
_max_seq_len(max_seq_len),
_hidden_size(hidden_size),
_heads(num_heads),
_training(true),
_pre_or_postLayerNorm(pre_or_postLayerNorm),
_qkv_linear(
typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
_attn_out_linear(
typename FeedForward<T>::Config(hidden_size, hidden_size)),
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
_max_batch_tokens),
_softmax(typename Softmax<T>::Config(num_heads)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
_max_batch_tokens * _heads * _max_seq_len),
_attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
_max_batch_tokens * _hidden_size),
_attn_scores(typename StridedBatchGemm<T>::Config(
(T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
CUBLAS_OP_N)),
_attn_context(typename StridedBatchGemm<T>::Config(
T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
assert(_hidden_size % _heads == 0);
}
template <typename T>
MultiHeadAttention<T>::~MultiHeadAttention() {
free_mem_buffer();
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
const T *input_mask_ptr,
T *output_ptr, T *buffer) {
T *q_tf_ptr = _qkv_ptr;
T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
if (_pre_or_postLayerNorm) {
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
_batch_tokens, _stream);
}
const T *gemmQKV_inp_ptr =
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
_cublasHandle);
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
_batch_size, _seq_len, 3, _heads / pg_size,
_hidden_size / _heads, _stream);
// attention scores, q*k
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
_cublasHandle);
// Softmax + Mask
_softmax.reset_size(_heads / pg_size);
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
_seq_len, _stream, true);
// attn prob dropout.
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
_batch_heads * _seq_len * _seq_len, _stream);
// attention context, score * v
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
_cublasHandle);
// [b, nh, s, ad] -> [b, s, nh, ad]
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
_hidden_size / pg_size, _heads / pg_size, 1,
_stream);
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
output_ptr, _cublasHandle);
// allreduce
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
} else {
auto data_type = torch::kFloat;
if (typeid(T) != typeid(float)) {
data_type = torch::kHalf;
}
auto output_tensor = torch::from_blob(
output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
torch::TensorOptions(torch::kCUDA).dtype(data_type));
std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
work->wait();
}
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
_attn_ob_ptr, _batch_tokens, _hidden_size,
_stream);
if (!_pre_or_postLayerNorm) {
// in-place ln since ln-input will not be used in post-ln mode
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
_batch_tokens, _stream);
}
}
template <typename T>
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
T *out_ptr) {
_stream = Context::Instance().get_stream();
_cublasHandle = Context::Instance().get_cublashandle();
T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim
attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer);
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
const T *input_mask_ptr,
const T *output_ptr,
const T *grad_output_ptr,
T *grad_input_ptr, T *buffer) {
cudaStream_t streams[2] = {_stream, _stream};
const T *q_tf_ptr = _qkv_ptr;
const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
// batch_dim = batch_size * seq_len * hidden_size
// buffer size: batch_dim * 3 + max(batch_dim * 3,
// batch_size * head_num * seq_len * seq_len)
T *grad_residual_ptr = buffer;
buffer += _batch_dim;
T *grad_input_buf_ptr = buffer; // batch_dim
T *grad_qkv_5d_ptr = buffer; // batch_dim * 3
buffer += 3 * _batch_dim / pg_size;
T *grad_qkv_4d_ptr = buffer; // batch_dim * 3
T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len
// buffer += max(3 * _batch_dim,
// batch_size * head_num * seq_len * seq_len);
if (_pre_or_postLayerNorm) {
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
grad_output_ptr, _batch_tokens,
_hidden_size, _stream);
} else {
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
_attn_nb_ptr, _batch_tokens, streams);
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
grad_residual_ptr, _batch_tokens,
_hidden_size, _stream);
}
// bw of output project
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
_attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
false);
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
_seq_len, _hidden_size / pg_size, _heads / pg_size,
_stream);
// bw of score * v
_attn_context.Backward(
_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
_attn_prob_dropout.d_dropout(grad_softmax_ptr,
_batch_heads * _seq_len * _seq_len, _stream);
_softmax.reset_size(_heads / pg_size);
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
_seq_len, _stream);
// bw of q * k
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
_cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
grad_qkv_5d_ptr);
// [3, b, nh, s, ad] -> [b, s, 3, h]
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
_seq_len, _hidden_size / pg_size, _heads / pg_size,
3, _stream);
const T *gemmQKV_inp_ptr =
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
_attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
true);
// allreduce
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
} else {
auto data_type = torch::kFloat;
if (typeid(T) != typeid(float)) {
data_type = torch::kHalf;
}
auto grad_input_tensor =
torch::from_blob(grad_input_buf_ptr,
{int(_batch_size), int(_seq_len), int(_hidden_size)},
torch::TensorOptions(torch::kCUDA).dtype(data_type));
std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
work->wait();
}
if (_pre_or_postLayerNorm) {
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
_attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
} else {
// FIXME later
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
_batch_size, _seq_len, _hidden_size, _stream);
}
}
template <typename T>
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
const T *input_ptr, const T *output_ptr,
const T *input_mask_ptr,
T *grad_input_ptr) {
_stream = Context::Instance().get_stream();
_cublasHandle = Context::Instance().get_cublashandle();
T *buffer = _shared_mem_ptr;
/*
buffer size needed by attn bw:
4 * _batch_dim + max(3 * _batch_dim,
_batch_size * _head_num * _seq_len * _seq_len);
*/
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
grad_input_ptr, buffer);
}
template <typename T>
void MultiHeadAttention<T>::SetTrainingMode(bool training) {
// Dropout will be skipped when not in training model.
_attn_prob_dropout.SetTrainingMode(training);
_attn_dropout.SetTrainingMode(training);
}
template <typename T>
T *MultiHeadAttention<T>::_shared_mem_ptr = nullptr;
template class MultiHeadAttention<float>;
template class MultiHeadAttention<__half>;
// x is torch::Tensor
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
template <typename T>
int create_multihead_attention(int layer_id, int max_batch_tokens,
int max_seq_len, int hidden_dim, int num_heads,
float attn_prob_dropout_ratio,
float hidden_dropout_ratio,
bool pre_or_postLayerNorm,
c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Context::Instance().set_stream(stream);
auto layer = std::make_shared<MultiHeadAttention<T>>(
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
layer->SetPG(pg_);
s_multihead_attention[layer_id] = layer;
std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
return 0;
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_fw(
int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
bool training_mode, bool prelayernorm) {
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
const T *input_ptr = (const T *)input.data_ptr();
const T *input_mask_ptr = (const T *)input_mask.data_ptr();
auto output = torch::empty_like(input);
T *out_ptr = (T *)output.data_ptr();
std::shared_ptr<MultiHeadAttention<T>> layer =
std::static_pointer_cast<MultiHeadAttention<T>>(
s_multihead_attention[layer_id]);
layer->set_cur_batch_shape(input.size(0), input.size(1));
layer->SetTrainingMode(training_mode);
layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr();
layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr();
layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr();
layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr();
layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr();
layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr();
layer->Forward(input_ptr, input_mask_ptr, out_ptr);
return {output};
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_bw(
int layer_id, const torch::Tensor &grad_dec_output,
const torch::Tensor &output, const torch::Tensor &input,
const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
const torch::Tensor &norm_bias) {
auto g_output = grad_dec_output.contiguous();
CHECK_INPUT(g_output);
CHECK_INPUT(output);
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
auto grad_input = torch::empty_like(input);
auto grad_in_proj_weight = torch::empty_like(in_proj_weight);
auto grad_in_proj_bias = torch::empty_like(in_proj_bias);
auto grad_out_proj_weight = torch::empty_like(out_proj_weight);
auto grad_out_proj_bias = torch::empty_like(out_proj_bias);
auto grad_norm_weight = torch::empty_like(norm_weight);
auto grad_norm_bias = torch::empty_like(norm_bias);
// inputs.
const T *grad_dec_output_ptr = (const T *)g_output.data_ptr();
const T *input_ptr = (const T *)input.data_ptr();
const T *output_ptr = (const T *)output.data_ptr();
const T *input_mask_ptr = (const T *)input_mask.data_ptr();
// outputs.
T *grad_input_ptr = (T *)grad_input.data_ptr();
std::shared_ptr<MultiHeadAttention<T>> layer =
std::static_pointer_cast<MultiHeadAttention<T>>(
s_multihead_attention[layer_id]);
layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr();
layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr();
layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr();
layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
grad_input_ptr);
return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
grad_norm_bias};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multihead_attention_fw_fp32", &multihead_attention_fw<float>,
"Multi-head Attention forward with fp32 (CUDA)");
m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>,
"Multi-head Attention forward with fp16 (CUDA)");
m.def("multihead_attention_bw_fp32", &multihead_attention_bw<float>,
"Multi-head Attention backward with fp32 (CUDA)");
m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>,
"Multi-head Attention backward with fp16 (CUDA)");
m.def("create_multihead_attention_fp32", &create_multihead_attention<float>,
"Create Multi-head Attention with fp32 (CUDA)");
m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>,
"Create Multi-head Attention with fp16 (CUDA)");
}

View File

@ -1,167 +0,0 @@
#pragma once
#include <c10/util/intrusive_ptr.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#else
#include <c10d/ProcessGroup.hpp>
#endif
#include <string>
#include <type_traits>
#include "cuda_util.h"
#include "dropout.h"
#include "feed_forward.h"
#include "normalize_layer.h"
#include "softmax.h"
#include "strided_batch_gemm.h"
template <typename T>
class MultiHeadAttention {
public:
MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
int hidden_size, int num_heads, float attn_dropout_ratio,
float hidden_output_dropout_ratio,
bool pre_or_postLayerNorm);
virtual ~MultiHeadAttention();
void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
void Backward(const T *grad_output_ptr, const T *input_ptr,
const T *output_ptr, const T *input_mask_ptr,
T *grad_input_ptr);
void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
T *buffer);
void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
const T *output_ptr, const T *grad_output_ptr,
T *grad_input_attn_layer_bwptr, T *buffer);
void set_cur_batch_shape(int batch_size, int seq_len) {
_batch_size = batch_size;
_seq_len = seq_len;
_batch_tokens = batch_size * seq_len;
_batch_heads = batch_size * _heads / pg_size;
_batch_dim = _batch_tokens * _hidden_size;
_attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads);
_attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len);
}
void SetTrainingMode(bool training);
inline bool IsTrainingMode() const { return _training; }
void SetPG(c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
pg = pg_;
pg_size = 1;
if (pg != c10::detail::UniqueVoidPtr()) {
pg_size = pg->getSize();
}
allocate_mem_buffer();
}
// weights ptr
const T *_attn_qkvw_ptr;
const T *_attn_qkvb_ptr;
const T *_attn_ow_ptr;
const T *_attn_ob_ptr;
const T *_attn_nw_ptr;
const T *_attn_nb_ptr;
// grads ptr
T *_grad_attn_qkvw_ptr;
T *_grad_attn_qkvb_ptr;
T *_grad_attn_ow_ptr;
T *_grad_attn_ob_ptr;
T *_grad_attn_nw_ptr;
T *_grad_attn_nb_ptr;
private:
void allocate_mem_buffer() {
// allocate local gpu memory
if (_pre_or_postLayerNorm) {
_gemmQKV_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
} else {
_gemmQKV_inp_ptr = nullptr;
}
_qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
_soft_out_ptr =
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
_ctx_bufB_ptr =
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
_attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
// buffer size needed by attn bw
size_t smem_size =
4 * _max_batch_tokens * _hidden_size / pg_size +
std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
_max_batch_tokens * _heads / pg_size * _max_seq_len);
if (!_shared_mem_ptr) {
cuda_free(_shared_mem_ptr);
_shared_mem_ptr = cuda_malloc<T>(smem_size);
}
}
void free_mem_buffer() {
// free local gpu memory
cuda_free(_gemmQKV_inp_ptr);
cuda_free(_qkv_ptr);
cuda_free(_soft_out_ptr);
cuda_free(_ctx_bufB_ptr);
cuda_free(_attn_o_inp_ptr);
// free shared gpu memory between layers
cuda_free(_shared_mem_ptr);
_shared_mem_ptr = nullptr;
}
// const parameter between batch
const size_t _layer_id;
const size_t _hidden_size;
const size_t _heads;
const size_t _max_batch_tokens;
const size_t _max_seq_len;
const bool _pre_or_postLayerNorm;
// dynamic parameter between batch
size_t _batch_size;
size_t _seq_len;
size_t _batch_tokens;
size_t _batch_heads;
size_t _batch_dim;
bool _training;
cublasHandle_t _cublasHandle;
cudaStream_t _stream;
// layers
FeedForward<T> _qkv_linear;
FeedForward<T> _attn_out_linear;
Normalize_Layer<T> _attn_ln;
Softmax<T> _softmax;
Dropout<T> _attn_prob_dropout;
Dropout<T> _attn_dropout;
StridedBatchGemm<T> _attn_scores;
StridedBatchGemm<T> _attn_context;
// local GPU memory
T *_gemmQKV_inp_ptr;
T *_qkv_ptr;
T *_soft_out_ptr;
T *_ctx_bufB_ptr;
T *_attn_o_inp_ptr;
// shared GPU memory between layer
static T *_shared_mem_ptr;
c10::intrusive_ptr<c10d::ProcessGroup> pg;
int pg_size;
};

View File

@ -1,8 +0,0 @@
#include <torch/extension.h>
#include "linear.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
"Linear SiLU (INT8)");
}

View File

@ -1,162 +0,0 @@
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
) {
auto M = input.size(0);
auto N = weight.size(0);
auto K = input.size(1);
using ElementOutput = float;
using ElementAccumulator = int32_t;
using ElementComputeEpilogue = float;
using ElementInputA = int8_t; // <- data type of elements in input matrix A
using ElementInputB = int8_t; // <- data type of elements in input matrix B
// The code section below describes matrix layout of input and output
// matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major
// for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
#if CUDA_ARCH >= 800
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<
ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This
// becomes the vector width of math
// instructions in epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue // <- data type for alpha in linear combination
// function
>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<
ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This
// becomes the vector width of math
// instructions in epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue // <- data type for alpha in linear combination
// function
>;
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
DefaultGemmCfg::InstructionShape,
EpilogueOp>;
#elif CUDA_ARCH >= 700
#define USE_TORCH_SILU
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
DefaultGemmCfg::InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
#error "Unsupported cuda arch"
#endif
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
auto out = bias.to(device).view({1, -1}).repeat({M, 1});
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
out_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{alpha, beta}, 1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
out = torch::silu(out);
#endif
return out;
}

View File

@ -1,12 +0,0 @@
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <iostream>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
);

View File

@ -1,3 +0,0 @@
from .mha import ColoAttention
__all__ = ["ColoAttention"]

View File

@ -1,80 +0,0 @@
import warnings
from typing import Optional
import torch
def is_ampere_or_better_gpu():
if torch.cuda.is_available():
device = torch.device("cuda")
properties = torch.cuda.get_device_properties(device)
if properties.major >= 8: # Ampere GPUs or newer
return True
return False
# "Check Ampere GPUs or newer"
HAS_FLASH_ATTN = False
if is_ampere_or_better_gpu():
HAS_FLASH_ATTN = True
else:
warnings.warn("FlashAttention only supports Ampere GPUs or newer.")
HAS_FLASH_ATTN = False
try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
HAS_FLASH_ATTN = True
except ImportError:
warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention")
HAS_FLASH_ATTN = False
if HAS_FLASH_ATTN:
pass
from .utils import SeqLenInfo
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: SeqLenInfo,
seq_len_info_kv: SeqLenInfo,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
"""
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
if padded:
if seq_len_info_kv == None:
seq_len_info_kv = seq_len_info_q
attn_out = flash_attn_varlen_func(
q,
k,
v,
seq_len_info_q.cu_seqlens,
seq_len_info_kv.cu_seqlens,
seq_len_info_q.max_seqlen,
seq_len_info_kv.max_seqlen,
dropout_p,
scale,
causal,
)
else:
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
return attn_out

View File

@ -1,70 +0,0 @@
import warnings
HAS_MEM_EFF_ATTN = False
try:
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
HAS_MEM_EFF_ATTN = True
except ImportError:
warnings.warn("please install xformers from https://github.com/facebookresearch/xformers")
HAS_MEM_EFF_ATTN = False
if HAS_MEM_EFF_ATTN:
"""
A general attention module using the flash attention kernels from xformers:
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
"""
from typing import Optional
import torch
from .utils import SeqLenInfo
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
def mem_eff_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: SeqLenInfo,
seq_len_info_kv: SeqLenInfo,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
attn_bias = None
if padded: # bert style
if not causal:
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
elif causal: # gpt style
attn_bias = LowerTriangularMask()
if bias is not None: # alibi / relative position embedding
assert allow_alibi, "flash attention with bias is not supported in this system."
assert causal, "attention with bias is only supported for causal attention so far."
attn_bias = attn_bias.add_bias(bias)
if padded:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
# shape: (b*s, n, d)
if padded:
out = out.squeeze(0)
return out

View File

@ -1,113 +0,0 @@
import math
from typing import Optional
import torch
from einops import rearrange
from ..scaled_softmax import AttnMaskType
from .flash_attn_2 import HAS_FLASH_ATTN
from .mem_eff_attn import HAS_MEM_EFF_ATTN
from .utils import Repad, SeqLenInfo, Unpad
if HAS_FLASH_ATTN:
from .flash_attn_2 import flash_attention
if HAS_MEM_EFF_ATTN:
from .mem_eff_attn import mem_eff_attention
class ColoAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
super().__init__()
assert (
embed_dim % num_heads == 0
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
raise Exception("flash attention can not support!")
@staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return Unpad.apply(tensor, indices)
@staticmethod
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
return Repad.apply(tensor, indices, batch_size, seq_len)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: Optional[AttnMaskType] = None,
bias: Optional[torch.Tensor] = None,
):
attn = None
if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None:
attn = flash_attention
else:
attn = mem_eff_attention
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
causal = attn_mask_type is not None and attn_mask_type.value > 1
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
# unpad
seq_len_info_q = None
seq_len_info_kv = None
if padded:
# bert style, unpad process
assert (
attn_mask is not None
), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, (
"attention mask is supposed to have shape (batch_size, seq_len), "
+ f"but got {attn_mask.dim()} dimensions."
)
# bert style
if tgt_len == src_len:
seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query, key, value = self.unpad(
torch.stack([query, key, value], dim=2), seq_len_info_q.indices
).unbind(dim=1)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
seq_len_info_kv = seq_len_info_q
else:
seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
dim=1
)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
out = attn(
query,
key,
value,
seq_len_info_q,
seq_len_info_kv,
dropout_p=self.dropout,
scale=self.scale,
causal=causal,
padded=padded,
)
# repad
if padded:
if batch_size > 1:
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
out = rearrange(out, "b s h d -> b s (h d)")
return out

View File

@ -1,82 +0,0 @@
from dataclasses import dataclass
from typing import Iterable, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.utils.device import get_current_device
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)

View File

@ -1,338 +0,0 @@
import math
from dataclasses import dataclass
import torch
from torch import nn
from torch.autograd import Function
def check_config(config):
if config.hidden_size % config.nhead != 0:
raise Exception("hidden_size % nhead != 0")
factor = 8 if config.fp16 else 4
upbound = factor * 1024 * 4
if config.hidden_size > upbound:
# as required by ln backward kernel currently
raise Exception(f"hidden_size > {upbound}")
head_dim = config.hidden_size // config.nhead
if head_dim % factor != 0:
# as required by reshape kernel
raise Exception(f"head_dim({head_dim}) % {factor} != 0")
def calc_offset(sizes):
offsets = [0]
tmp = 0
for x in sizes:
tmp += x
offsets.append(tmp)
return offsets
colossal_multihead_attention = None
@dataclass
class Config:
max_batch_tokens: int # max batch token numbers
max_seq_len: int # max sequence length
hidden_size: int # size of transformer hidden layers
nhead: int # number of heads in attention
attn_prob_dropout_ratio: float # attention score dropout ratio
hidden_dropout_ratio: float # dropout ration before residual
norm_first: bool # norm_first
fp16: bool # fp16 precision
class MultiHeadAttention1DFunc(Function):
@staticmethod
def forward(
ctx,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
config,
):
cuda_module = colossal_multihead_attention
forward_func = (
cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32
)
if config.fp16:
input = input.to(torch.half)
input_mask = input_mask.to(torch.half)
(output,) = forward_func(
config.layer_id,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
config.training,
config.norm_first,
)
if config.is_grad_enabled and config.training:
ctx.save_for_backward(
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
)
ctx.config = config
return output
@staticmethod
def backward(ctx, grad_output):
assert ctx.config.training
cuda_module = colossal_multihead_attention
backward_func = (
cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32
)
(
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
) = ctx.saved_tensors
grad_input = None
grad_in_proj_weight = None
grad_in_proj_bias = None
grad_out_proj_weight = None
grad_out_proj_bias = None
grad_norm_weight = None
grad_norm_bias = None
if ctx.config.fp16:
grad_output = grad_output.to(torch.half)
output = output.to(torch.half)
input = input.to(torch.half)
input_mask = input_mask.to(torch.half)
(
grad_input,
grad_in_proj_weight,
grad_in_proj_bias,
grad_out_proj_weight,
grad_out_proj_bias,
grad_norm_weight,
grad_norm_bias,
) = backward_func(
ctx.config.layer_id,
grad_output,
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
)
return (
grad_input,
None,
grad_in_proj_weight,
grad_in_proj_bias,
grad_out_proj_weight,
grad_out_proj_bias,
grad_norm_weight,
grad_norm_bias,
None,
)
class MultiHeadAttention(nn.Module):
"""Initialize the MultiHeadAttention.
Static variable:
layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.
Arguments:
hidden_size: Total dimension of hidden_size.
nhead: Number of parallel attention heads.
batch_size: Batch Size for one forward
max_seq_len: Max length of input sequence
dropout: Dropout probability
norm_first: perform LayerNorms before attention
"""
layer_id = 0
def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
super(MultiHeadAttention, self).__init__()
self.config = Config(
batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16
)
check_config(self.config)
self.pg = pg
self.pg_size = 1
if self.pg:
self.pg_size = pg.size()
self.config.layer_id = MultiHeadAttention.layer_id
MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1
# Load cuda modules if needed
global colossal_multihead_attention
if colossal_multihead_attention is None:
from colossalai.kernel.op_builder import MultiHeadAttnBuilder
multihead_attention = MultiHeadAttnBuilder().load()
colossal_multihead_attention = multihead_attention
# create the layer in cuda kernels.
cuda_module = colossal_multihead_attention
create_layer_func = (
cuda_module.create_multihead_attention_fp16
if self.config.fp16
else cuda_module.create_multihead_attention_fp32
)
create_layer_func(
self.config.layer_id,
self.config.max_batch_tokens,
self.config.max_seq_len,
self.config.hidden_size,
self.config.nhead,
self.config.attn_prob_dropout_ratio,
self.config.hidden_dropout_ratio,
self.config.norm_first,
self.pg,
)
hs = self.config.hidden_size
self.precision = torch.float32
if self.config.fp16:
self.precision = torch.half
self.hs_per_rank = int(hs / self.pg_size)
self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs))
self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank))
self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank))
self.out_proj_bias = nn.Parameter(torch.Tensor(hs))
self.norm_weight = nn.Parameter(torch.Tensor(hs))
self.norm_bias = nn.Parameter(torch.Tensor(hs))
self.reset_parameters()
torch.cuda.empty_cache()
def calc_bound(self, w):
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
bound = 1.0 / math.sqrt(fan_in)
return bound
def reset_parameters(self):
hs = self.config.hidden_size
nn.init.zeros_(self.out_proj_bias)
nn.init.ones_(self.norm_weight)
nn.init.zeros_(self.norm_bias)
if self.pg_size > 1:
rank_in_pg = torch.distributed.get_rank(self.pg)
attn_qkvw_global = torch.empty(hs * 3, hs)
attn_qkvb_global = torch.empty(hs * 3)
nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0))
bound = self.calc_bound(attn_qkvw_global)
nn.init.uniform_(attn_qkvb_global, -bound, bound)
attn_qkvw_global = attn_qkvw_global.cuda()
attn_qkvb_global = attn_qkvb_global.cuda()
torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg)
torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg)
attn_qkvw_global = attn_qkvw_global.cpu()
attn_qkvb_global = attn_qkvb_global.cpu()
with torch.no_grad():
self.in_proj_weight.copy_(
attn_qkvw_global.view(3, hs, hs)[
:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), :
]
)
self.in_proj_bias.copy_(
attn_qkvb_global.view(3, hs)[
:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)
]
)
attn_ow_global = torch.empty(hs, hs)
nn.init.xavier_uniform_(attn_ow_global, 1.0)
attn_ow_global = attn_ow_global.cuda()
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
attn_ow_global = attn_ow_global.cpu()
with torch.no_grad():
self.out_proj_weight.copy_(
attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)]
)
else:
attn_qkvw = self.in_proj_weight.view(-1, hs)
nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0))
bound = self.calc_bound(attn_qkvw)
nn.init.uniform_(self.in_proj_bias, -bound, bound)
nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
def state_dict(self, destination=None, prefix="", keep_vars=False):
destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
return destination
def forward(self, hidden_states, encoder_padding_mask):
self.config.training = self.training
self.config.is_grad_enabled = torch.is_grad_enabled()
hidden_states = hidden_states.contiguous()
encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()
bs, sl, dim = hidden_states.size()
if bs * sl > self.config.max_batch_tokens:
raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
if sl > self.config.max_seq_len:
raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
if len(encoder_padding_mask.size()) == 1:
assert bs == 1 and sl == encoder_padding_mask.size(0)
else:
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
output = MultiHeadAttention1DFunc.apply(
hidden_states,
encoder_padding_mask,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.norm_weight,
self.norm_bias,
self.config,
)
return output.to(self.precision)

View File

@ -0,0 +1 @@
../../extensions

View File

@ -1,7 +1,7 @@
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
from colossalai.utils import get_current_device
from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
@ -46,11 +46,13 @@ def warmup_jit_fusion(
):
"""Compile JIT functions before the main training steps"""
embed = Embedding(vocab_size, hidden_size).to(get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device())
x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
x = torch.randint(
vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device()
)
x = embed(x)
y, y_bias = linear_1(x)
z, z_bias = linear_2(y)
@ -58,8 +60,8 @@ def warmup_jit_fusion(
# prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]):
for _ in range(10):
bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device())
input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device())
bias.requires_grad, input_.requires_grad = bias_grad, input_grad
bias_gelu_impl(input_, bias)
@ -69,9 +71,9 @@ def warmup_jit_fusion(
# prop and recomputation
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
for _ in range(10):
input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device())
residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device())
bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device())
input_.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad

View File

@ -0,0 +1,109 @@
import warnings
from typing import List
from .extensions import (
CpuAdamArmExtension,
CpuAdamX86Extension,
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
FusedOptimizerCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
)
from .extensions.base_extension import _Extension
__all__ = [
"KernelLoader",
"CPUAdamLoader",
"LayerNormLoader",
"MoeLoader",
"FusedOptimizerLoader",
"ScaledMaskedSoftmaxLoader",
"ScaledUpperTriangleMaskedSoftmaxLoader",
]
class KernelLoader:
"""
An abstract class which offers encapsulation to the kernel loading process.
Usage:
kernel_loader = KernelLoader()
kernel = kernel_loader.load()
"""
REGISTRY: List[_Extension] = []
@classmethod
def register_extension(cls, extension: _Extension):
"""
This classmethod is an extension point which allows users to register their customized
kernel implementations to the loader.
Args:
extension (_Extension): the extension to be registered.
"""
cls.REGISTRY.append(extension)
def load(self, ext_name: str = None):
"""
Load the kernel according to the current machine.
Args:
ext_name (str): the name of the extension to be loaded. If not specified, the loader
will try to look for an kernel available on the current machine.
"""
exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]
# look for exts which can be built/loaded on the current machine
if ext_name:
usable_exts = list(filter(lambda ext: ext.name == ext_name, exts))
else:
usable_exts = []
for ext in exts:
if ext.is_hardware_available():
# make sure the machine is compatible during kernel loading
ext.assert_hardware_compatible()
usable_exts.append(ext)
assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
if len(usable_exts) > 1:
# if more than one usable kernel is found, we will try to load the kernel with the highest priority
usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
warnings.warn(
f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}"
)
return usable_exts[0].load()
class CPUAdamLoader(KernelLoader):
REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]
class LayerNormLoader(KernelLoader):
REGISTRY = [LayerNormCudaExtension]
class MoeLoader(KernelLoader):
REGISTRY = [MoeCudaExtension]
class FusedOptimizerLoader(KernelLoader):
REGISTRY = [FusedOptimizerCudaExtension]
class ScaledMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]
class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]
class FlashAttentionLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]

View File

@ -1 +0,0 @@
../../op_builder

View File

@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
@ -28,7 +28,7 @@ def load_fused_optim():
global fused_optim
if fused_optim is None:
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):

View File

@ -1,18 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from colossalai.utils.device import autocast
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler
autocast = get_accelerator().autocast
class TorchAMPOptimizer(OptimizerWrapper):
"""A wrapper class which integrate Pytorch AMP with an optimizer

View File

@ -8,9 +8,9 @@ from typing import List, Tuple, Union
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
if isinstance(recv_shapes, torch.Size):
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
buffer_recv = torch.empty(
recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
)
return buffer_recv, recv_split
buffer_recv = []
for recv_shape in recv_shapes:
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
tensor_recv = torch.empty(
recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
)
buffer_recv.append(tensor_recv)
return buffer_recv, recv_split

View File

@ -3,9 +3,9 @@
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
current_rank = gpc.get_global_rank()
tensor_recv_prev = torch.empty(
buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype
)
# send to next rank
@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
req.wait()
# To protect against race condition when using batch_isend_irecv().
synchronize()
get_accelerator().synchronize()
return tensor_recv_prev

View File

@ -3,9 +3,9 @@ from typing import List, Tuple, Union
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device
TensorShape = Union[torch.Size, List[int], Tuple[int]]
@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1:

View File

@ -6,8 +6,8 @@ from typing import Callable, Iterable
import torch
from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
class BaseSchedule(ABC):
@ -29,12 +29,12 @@ class BaseSchedule(ABC):
def _move_tensor(element):
if torch.is_tensor(element):
if not element.is_cuda:
return element.to(get_current_device()).detach()
return element.to(get_accelerator().get_current_device()).detach()
return element
def _move_to_device(self, data):
if isinstance(data, torch.Tensor):
data = data.to(get_current_device())
data = data.to(get_accelerator().get_current_device())
elif isinstance(data, (list, tuple)):
data_to_return = []
for element in data:

View File

@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union
import torch.cuda
import colossalai.legacy.communication as comm
from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
from colossalai.logging import get_dist_logger
from colossalai.utils.device import get_current_device
from ._base_schedule import BaseSchedule
@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule):
output_objs = []
return_tensors = []
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None
# Used for tensor meta information communication
@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if not forward_only:
output_obj_grads = [[] for _ in range(len(model))]
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None

View File

@ -6,10 +6,10 @@ from typing import Iterable, Tuple
import torch.cuda
import colossalai.legacy.communication.p2p_v2 as comm
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine
from colossalai.utils.device import get_current_device
from ._pipeline_schedule import PipelineSchedule
@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule):
output_objs = []
return_tensors = []
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None

View File

@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from colossalai.accelerator import get_accelerator
from colossalai.context import Config, ConfigException
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence
from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
def get_default_parser():
@ -309,9 +309,9 @@ def initialize(
else:
if isinstance(model, nn.Module):
# first sync model across dp ranks
model.to(get_current_device())
model.to(get_accelerator().get_current_device())
elif isinstance(model, Callable):
model = model().to(get_current_device())
model = model().to(get_accelerator().get_current_device())
# optimizer maybe a optimizer_cls
if isinstance(optimizer, Callable):

View File

@ -3,8 +3,8 @@ from typing import Callable
from torch import dtype, nn
from colossalai.accelerator import get_accelerator
from colossalai.nn import init
from colossalai.utils import get_current_device
from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
@ -83,7 +83,7 @@ class Embedding(ColossalaiModule):
embed = (
nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs)
.to(dtype)
.to(get_current_device())
.to(get_accelerator().get_current_device())
)
weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
elif num_embeddings <= vocab_parallel_limit:

View File

@ -1,6 +1,6 @@
from torch import nn
from colossalai.utils import get_current_device
from colossalai.accelerator import get_accelerator
from ..parallel_1d import LayerNorm1D
from ..parallel_2d import LayerNorm2D
@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule):
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel is None:
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device())
else:
norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
super().__init__(norm)

View File

@ -10,7 +10,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from colossalai.kernel import LayerNorm
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import broadcast
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.context.parallel_context import global_context as gpc
@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.utils.device import get_current_device
from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer):
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
torch.empty(
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.reset_parameters(weight_initializer)
@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer):
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.reset_parameters(weight_initializer)

View File

@ -5,10 +5,10 @@ import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device
def matmul_2d(
@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[0])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[-1], B.shape[-1])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases

View File

@ -8,6 +8,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import broadcast
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.core import global_context as gpc
@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@ -82,7 +82,7 @@ class Linear2D(ParallelLayer):
self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
)
@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer):
self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
# create parameters
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
if bias:
@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.bias = Parameter(
torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
)
self.cls_token = Parameter(
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
torch.zeros(
(1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.pos_embed = Parameter(
torch.zeros(
(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
(1, self.num_patches + 1, self.embed_size_per_partition),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer):
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
torch.empty(
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.reset_parameters(weight_initializer)
@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
torch.empty(
self.num_classes,
self.input_size_per_partition,
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
self.bias = Parameter(
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
)
else:
self.bias = None
@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer):
self.output_size_per_partition = divide(num_classes, self.summa_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False

View File

@ -5,10 +5,10 @@ import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device
def get_parallel_group(parallel_mode: ParallelMode):
@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[0])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[-1], B.shape[-1])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
if row_rank == 0:
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
bias_temp = torch.zeros(
output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device()
)
src_rank = (
col_rank
+ dep_rank * tesseract_dim**2
@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function):
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device())
dist.all_gather(
list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode)
)

View File

@ -8,6 +8,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import broadcast
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.core import global_context as gpc
@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer):
self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
)
@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # *
# create parameters
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
if bias:
@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.bias = Parameter(
torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
)
self.cls_token = Parameter(
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
torch.zeros(
(1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.pos_embed = Parameter(
torch.zeros(
(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
(1, self.num_patches + 1, self.embed_size_per_partition),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer):
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
torch.empty(
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.reset_parameters(weight_initializer)
@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
torch.empty(
self.num_classes,
self.input_size_per_partition,
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
self.bias = Parameter(
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
)
else:
self.bias = None
@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False

View File

@ -8,6 +8,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import all_reduce, broadcast
from colossalai.legacy.constants import (
INPUT_GROUP_3D,
@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.utils.device import get_current_device
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (
@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer):
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
)
if bias:
self.bias = Parameter(
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
torch.zeros(
self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
)
)
else:
self.bias = None
@ -202,13 +204,15 @@ class Linear3D(ParallelLayer):
torch.empty(
self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
if bias:
self.bias = Parameter(
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
torch.zeros(
self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
)
)
else:
self.bias = None
@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
torch.empty(
self.num_classes,
self.in_features_per_partition,
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
self.bias = Parameter(
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
)
else:
self.bias = None
@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer):
torch.empty(
self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.has_weight = True
if bias:
self.bias = Parameter(
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
torch.zeros(
self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
)
)
else:
self.bias = None
@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer):
self.weight = nn.Parameter(
torch.empty(
(embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
(embed_size_per_partition, in_chans, *self.patch_size),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(
torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
)
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype)
)
self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
torch.zeros(
(1, self.num_patches + 1, embed_size_per_partition),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer):
self.embed_kwargs = kwargs
self.weight = nn.Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
torch.empty(
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.reset_parameters(weight_initializer)
@ -1019,7 +1042,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
device=get_accelerator().get_current_device(),
dtype=dtype,
)
)

Some files were not shown because too many files have changed in this diff Show More