mirror of https://github.com/hpcaitech/ColossalAI
merge commit
commit
c565519913
|
@ -140,7 +140,7 @@ jobs:
|
|||
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
CUDA_EXT=1 pip install -v -e .
|
||||
BUILD_EXT=1 pip install -v -e .
|
||||
pip install -r requirements/requirements-test.txt
|
||||
|
||||
- name: Store Colossal-AI Cache
|
||||
|
@ -160,9 +160,7 @@ jobs:
|
|||
--ignore tests/test_gptq \
|
||||
--ignore tests/test_infer_ops \
|
||||
--ignore tests/test_legacy \
|
||||
--ignore tests/test_moe \
|
||||
--ignore tests/test_smoothquant \
|
||||
--ignore tests/test_checkpoint_io \
|
||||
tests/
|
||||
env:
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
|
|
|
@ -12,7 +12,7 @@ jobs:
|
|||
if: github.repository == 'hpcaitech/ColossalAI'
|
||||
runs-on: [self-hosted, gpu]
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
|
||||
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
|
@ -23,6 +23,7 @@ jobs:
|
|||
ngpu=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||
endIndex=$(($ngpu-1))
|
||||
for i in $(seq 0 $endIndex);
|
||||
do
|
||||
gpu_used=$(nvidia-smi -i $i --query-gpu=memory.used --format=csv,noheader,nounits)
|
||||
[ "$gpu_used" -gt "2000" ] && avai=false
|
||||
done
|
||||
|
@ -54,7 +55,7 @@ jobs:
|
|||
if: steps.check-avai.outputs.avai == 'true'
|
||||
run: |
|
||||
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
|
||||
CUDA_EXT=1 pip install -v -e .
|
||||
BUILD_EXT=1 pip install -v -e .
|
||||
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
|
||||
pip install -r requirements/requirements-test.txt
|
||||
|
||||
|
|
|
@ -45,9 +45,9 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}}
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
|
||||
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
options: --gpus all --rm -v /data/scratch/examples-data:/data/
|
||||
timeout-minutes: 10
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: 📚 Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
|
|
@ -77,7 +77,7 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
|
||||
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
options: --gpus all --rm -v /data/scratch/examples-data:/data/
|
||||
timeout-minutes: 20
|
||||
concurrency:
|
||||
|
|
|
@ -34,7 +34,7 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
|
||||
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: 📚 Checkout
|
||||
|
|
|
@ -18,7 +18,7 @@ jobs:
|
|||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
|
||||
runs-on: [self-hosted, gpu]
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
|
|
|
@ -20,7 +20,7 @@ jobs:
|
|||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
|
||||
runs-on: [self-hosted, gpu]
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
|
|
|
@ -19,7 +19,7 @@ jobs:
|
|||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
|
||||
runs-on: [self-hosted, gpu]
|
||||
container:
|
||||
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
|
||||
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
|
||||
volumes:
|
||||
- /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa
|
||||
- /data/scratch/llama-tiny:/data/scratch/llama-tiny
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
include *.txt README.md
|
||||
recursive-include requirements *.txt
|
||||
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
|
||||
recursive-include op_builder *.py
|
||||
recursive-include extensions *.py *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
|
||||
|
|
|
@ -142,7 +142,7 @@ distributed training and inference in a few lines.
|
|||
[[Modelscope model weights]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-13b-base/summary)
|
||||
|
||||
| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot)| AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) |
|
||||
| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: |
|
||||
| :-----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: |
|
||||
| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
|
||||
| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
|
||||
| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
|
||||
|
@ -160,6 +160,7 @@ distributed training and inference in a few lines.
|
|||
| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - |
|
||||
| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - |
|
||||
| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 |
|
||||
| **Colossal-LLaMA-2-13b-base** | Llama-2-13B | **0.025T** | 56.42 | 61.80 | 54.69 | 69.53 | 60.3 |
|
||||
|
||||
|
||||
### ColossalChat
|
||||
|
|
|
@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler
|
|||
from tqdm import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .base import OnPolicyTrainer
|
||||
from .callbacks import Callback
|
||||
|
@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer):
|
|||
self.critic_optim = critic_optim
|
||||
|
||||
self.offload_inference_models = offload_inference_models
|
||||
self.device = get_current_device()
|
||||
self.device = get_accelerator().get_current_device()
|
||||
|
||||
def _before_fit(
|
||||
self,
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch.nn as nn
|
|||
import colossalai
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
||||
|
||||
from .ddp import DDPStrategy
|
||||
|
@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy):
|
|||
|
||||
warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
|
||||
|
||||
# colossalai has changed api for get_current_device in 0.3.4 version or newer
|
||||
try:
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
chunk_init_device = get_accelerator().get_current_device()
|
||||
except:
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
chunk_init_device = get_current_device()
|
||||
|
||||
# NOTE: dist should be initialized before calling get_current_device()
|
||||
plugin_initializer = lambda: GeminiPlugin(
|
||||
chunk_init_device=get_current_device(),
|
||||
chunk_init_device=chunk_init_device,
|
||||
placement_policy=placement_policy,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
|
|
|
@ -6,12 +6,12 @@ Initialize new tokenizer for continual pre-training
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
|
||||
from transformers.models.llama.tokenization_llama import LlamaTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
|
|
@ -16,7 +16,10 @@ import torch
|
|||
|
||||
|
||||
def unwrap(model):
|
||||
return model.unwrap().module
|
||||
if hasattr(model, "module"):
|
||||
return unwrap_model(model.module)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def neftune_post_forward_hook(module, input, output):
|
||||
|
|
|
@ -4,41 +4,34 @@
|
|||
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_llama2.dataset.loader import (
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import (
|
||||
GeminiPlugin,
|
||||
LowLevelZeroPlugin,
|
||||
HybridParallelPlugin,
|
||||
)
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from colossal_llama2.dataset.loader import (
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
|
@ -215,9 +208,18 @@ def main() -> None:
|
|||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
)
|
||||
|
||||
# colossalai has changed api for get_current_device in 0.3.4 version or newer
|
||||
try:
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
current_device = get_accelerator().get_current_device()
|
||||
except:
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
current_device = get_current_device()
|
||||
|
||||
init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
||||
# Freeze part of parameters.
|
||||
|
@ -320,7 +322,7 @@ def main() -> None:
|
|||
initial=start_step,
|
||||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
batch_output = model(**batch)
|
||||
|
||||
|
@ -372,9 +374,7 @@ def main() -> None:
|
|||
# Final save.
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
|
||||
)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
|
|
@ -136,6 +136,19 @@ class ColossalLLM(LLM):
|
|||
"""Get the identifying parameters."""
|
||||
return {"n": self.n}
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
"""Return the ordered ids of the tokens in a text.
|
||||
|
||||
Args:
|
||||
text: The string input to tokenize.
|
||||
|
||||
Returns:
|
||||
A list of ids corresponding to the tokens in the text, in order they occur
|
||||
in the text.
|
||||
"""
|
||||
# use the colossal llm's tokenizer instead of langchain's cached GPT2 tokenizer
|
||||
return self.api.tokenizer.encode(text)
|
||||
|
||||
|
||||
class VllmLLM(LLM):
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
|
||||
from . import accelerator
|
||||
|
||||
try:
|
||||
# .version will be created by setup.py
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# 🚀 Accelerator
|
||||
|
||||
## 🔗 Table of Contents
|
||||
|
||||
- [🚀 Accelerator](#-accelerator)
|
||||
- [🔗 Table of Contents](#-table-of-contents)
|
||||
- [📚 Introduction](#-introduction)
|
||||
- [📌 Design and Acknowledgement](#-design-and-acknowledgement)
|
||||
|
||||
## 📚 Introduction
|
||||
|
||||
This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platform with a simple `auto_set_accelerator()` API.
|
||||
|
||||
## 📌 Design and Acknowledgement
|
||||
|
||||
Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work.
|
||||
|
||||
We implemented this accelerator module from scratch. At the same time, we have implemented our own modifications:
|
||||
1. we updated the accelerator API names to be aligned with PyTorch's native API names.
|
||||
2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled.
|
|
@ -0,0 +1,15 @@
|
|||
from .api import auto_set_accelerator, get_accelerator, set_accelerator
|
||||
from .base_accelerator import BaseAccelerator
|
||||
from .cpu_accelerator import CpuAccelerator
|
||||
from .cuda_accelerator import CudaAccelerator
|
||||
from .npu_accelerator import NpuAccelerator
|
||||
|
||||
__all__ = [
|
||||
"get_accelerator",
|
||||
"set_accelerator",
|
||||
"auto_set_accelerator",
|
||||
"BaseAccelerator",
|
||||
"CudaAccelerator",
|
||||
"NpuAccelerator",
|
||||
"CpuAccelerator",
|
||||
]
|
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
from .cpu_accelerator import CpuAccelerator
|
||||
from .cuda_accelerator import CudaAccelerator
|
||||
from .npu_accelerator import NpuAccelerator
|
||||
|
||||
__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"]
|
||||
|
||||
|
||||
_ACCELERATOR = None
|
||||
|
||||
|
||||
# we use ordered dictionary here to associate the
|
||||
# order with device check priority
|
||||
# i.e. auto_set_accelerator will check cuda first
|
||||
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)
|
||||
|
||||
|
||||
def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
|
||||
"""
|
||||
Set the global accelerator for the current process.
|
||||
|
||||
Args:
|
||||
accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs.
|
||||
"""
|
||||
|
||||
global _ACCELERATOR
|
||||
|
||||
if isinstance(accelerator, str):
|
||||
_ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]()
|
||||
elif isinstance(accelerator, BaseAccelerator):
|
||||
_ACCELERATOR = accelerator
|
||||
else:
|
||||
raise TypeError("accelerator must be either a string or an instance of BaseAccelerator")
|
||||
|
||||
|
||||
def auto_set_accelerator() -> None:
|
||||
"""
|
||||
Automatically check if any accelerator is available.
|
||||
If an accelerator is availabe, set it as the global accelerator.
|
||||
"""
|
||||
global _ACCELERATOR
|
||||
|
||||
for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
|
||||
try:
|
||||
accelerator = accelerator_cls()
|
||||
if accelerator_name == "cpu" or accelerator.is_available():
|
||||
_ACCELERATOR = accelerator
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
if _ACCELERATOR is None:
|
||||
raise RuntimeError("No accelerator is available.")
|
||||
|
||||
|
||||
def get_accelerator() -> BaseAccelerator:
|
||||
"""
|
||||
Return the accelerator for the current process. If the accelerator is not initialized, it will be initialized
|
||||
to the default accelerator type.
|
||||
|
||||
Returns: the accelerator for the current process.
|
||||
"""
|
||||
global _ACCELERATOR
|
||||
|
||||
if _ACCELERATOR is None:
|
||||
auto_set_accelerator()
|
||||
return _ACCELERATOR
|
|
@ -0,0 +1,320 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["BaseAccelerator"]
|
||||
|
||||
|
||||
class BaseAccelerator(ABC):
|
||||
support_set_device: bool = True
|
||||
|
||||
def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
|
||||
self._name = name
|
||||
self._communication_backend = communication_backend
|
||||
self._is_synchronous = is_synchronous
|
||||
|
||||
# =======================
|
||||
# immutable attributes
|
||||
# =======================
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Return the name of the accelerator.
|
||||
"""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def communication_backend(self) -> str:
|
||||
"""
|
||||
Return the name of the backend communication library.
|
||||
"""
|
||||
return self._communication_backend
|
||||
|
||||
@property
|
||||
def is_synchronous(self) -> bool:
|
||||
"""
|
||||
Return whether the accelerator is a synchronous device.
|
||||
"""
|
||||
return self._is_synchronous
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cls_name = self.__class__.__name__
|
||||
return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})"
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def get_version(self) -> str:
|
||||
"""
|
||||
Return the version of the accelerator which torch is built against.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""
|
||||
Return the current device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
|
||||
def set_to_device(self, models: Any) -> Any:
|
||||
"""
|
||||
Send model to device.
|
||||
|
||||
:param models: nn.module or a list of module
|
||||
"""
|
||||
if isinstance(models, list) and len(models) > 1:
|
||||
ret = []
|
||||
for model in models:
|
||||
ret.append(model.to(self.get_current_device()))
|
||||
return ret
|
||||
elif isinstance(models, list):
|
||||
return models[0].to(self.get_current_device())
|
||||
else:
|
||||
return models.to(self.get_current_device())
|
||||
|
||||
@abstractmethod
|
||||
def get_device_capability(self, device=None) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the capability of a device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_device_name(self, device=None) -> str:
|
||||
"""
|
||||
Gets the name of a device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_device_properties(self, device):
|
||||
"""
|
||||
Gets the properties of a device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def utilization(self, device=None) -> int:
|
||||
"""
|
||||
Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc.
|
||||
"""
|
||||
|
||||
# =======================
|
||||
# random number generator APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def get_rng_state(self, device="cuda") -> torch.Tensor:
|
||||
"""
|
||||
Returns the random number generator state of the specified device as a ByteTensor.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_rng_state_all(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
|
||||
"""
|
||||
Sets the random number generator state of the specified device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
|
||||
"""
|
||||
Sets the random number generator state of all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def manual_seed(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers for the current device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def manual_seed_all(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers on all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def seed(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number for the current device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def seed_all(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number on all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def initial_seed(self) -> int:
|
||||
"""
|
||||
Returns the current random seed of the current device.
|
||||
"""
|
||||
|
||||
# =======================
|
||||
# memory management APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def empty_cache(self) -> None:
|
||||
"""
|
||||
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_stats(self, device=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of CUDA memory allocator statistics for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_summary(self, device=None, abbreviated=False) -> str:
|
||||
"""
|
||||
Returns a human-readable printout of the current memory allocator statistics for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_snapshot(self):
|
||||
"""
|
||||
Returns a snapshot of the CUDA memory allocator state across all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current device memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def max_memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum device memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reset_max_memory_allocated(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum device memory occupied by tensors for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reset_max_memory_cached(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current device memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def max_memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum device memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
|
||||
"""
|
||||
Set memory fraction for a process.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reset_peak_memory_stats(self, device=None) -> None:
|
||||
"""
|
||||
Resets the "peak" stats tracked by the device memory allocator.
|
||||
"""
|
||||
|
||||
# =======================
|
||||
# streams and events APIs
|
||||
# =======================
|
||||
|
||||
@abstractmethod
|
||||
def Stream(self, device=None, priority=0, **kwargs):
|
||||
"""
|
||||
A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
"""
|
||||
device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def current_stream(self, device=None):
|
||||
"""
|
||||
Returns the currently selected Stream for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def default_stream(self, device=None):
|
||||
"""
|
||||
Returns the default Stream for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_stream(self, stream_):
|
||||
"""
|
||||
Sets the current stream.This is a wrapper API to set the stream.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def stream(self, stream_):
|
||||
"""
|
||||
Wrapper around the Context-manager StreamContext that selects a given stream.
|
||||
"""
|
||||
|
||||
# =======================
|
||||
# amp APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def autocast(
|
||||
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
|
@ -0,0 +1,283 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
__all__ = ["CpuAccelerator"]
|
||||
|
||||
|
||||
class CpuAccelerator(BaseAccelerator):
|
||||
support_set_device: bool = False
|
||||
"""
|
||||
Accelerator class for cpu.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False)
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def get_version(self) -> str:
|
||||
"""
|
||||
Return the version of the accelerator which torch is built against.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""
|
||||
Return the current device.
|
||||
"""
|
||||
return torch.device("cpu")
|
||||
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
return True
|
||||
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def get_device_capability(self, device=None) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the cuda capability of a device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def get_device_name(self, device=None) -> str:
|
||||
"""
|
||||
Gets the name of a device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def get_device_properties(self, device):
|
||||
"""
|
||||
Gets the properties of a device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def utilization(self, device=None) -> int:
|
||||
"""
|
||||
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
# =======================
|
||||
# random number generator APIs
|
||||
# =======================
|
||||
def get_rng_state(self, device=None) -> torch.Tensor:
|
||||
"""
|
||||
Returns the random number generator state of the specified GPU as a ByteTensor.
|
||||
"""
|
||||
return torch.get_rng_state(device)
|
||||
|
||||
def get_rng_state_all(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None:
|
||||
"""
|
||||
Sets the random number generator state of the specified GPU.
|
||||
"""
|
||||
torch.set_rng_state(new_state)
|
||||
|
||||
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
|
||||
"""
|
||||
Sets the random number generator state of all devices.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def manual_seed(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers for the current GPU.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def manual_seed_all(self, seed: int) -> None:
|
||||
"""
|
||||
Set the random seed for the all processes.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def seed(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number for the current GPU.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def seed_all(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number on all GPUs.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def initial_seed(self) -> int:
|
||||
"""
|
||||
Returns the current random seed of the current GPU.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
# =======================
|
||||
# memory management APIs
|
||||
# =======================
|
||||
|
||||
def empty_cache(self) -> None:
|
||||
"""
|
||||
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_stats(self, device=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of CUDA memory allocator statistics for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_summary(self, device=None, abbreviated=False) -> str:
|
||||
"""
|
||||
Returns a human-readable printout of the current memory allocator statistics for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_snapshot(self):
|
||||
"""
|
||||
Returns a snapshot of the CUDA memory allocator state across all devices.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def max_memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def reset_max_memory_allocated(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def reset_max_memory_cached(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return psutil.Process().memory_info().rss
|
||||
|
||||
def max_memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
|
||||
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
|
||||
"""
|
||||
Set memory fraction for a process.
|
||||
"""
|
||||
max_memory = int(psutil.virtual_memory().total * fraction)
|
||||
_, hard = resource.getrlimit(resource.RLIMIT_AS)
|
||||
resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard))
|
||||
|
||||
def reset_peak_memory_stats(self, device=None) -> None:
|
||||
"""
|
||||
Resets the "peak" stats tracked by the CUDA memory allocator.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
# =======================
|
||||
# streams and events APIs
|
||||
# =======================
|
||||
|
||||
def Stream(self, device=None, priority=0, **kwargs):
|
||||
"""
|
||||
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
"""
|
||||
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def current_stream(self, device=None):
|
||||
"""
|
||||
Returns the currently selected Stream for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def default_stream(self, device=None):
|
||||
"""
|
||||
Returns the default Stream for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def set_stream(self, stream_):
|
||||
"""
|
||||
Sets the current stream.This is a wrapper API to set the stream.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def stream(self, stream_):
|
||||
"""
|
||||
Wrapper around the Context-manager StreamContext that selects a given stream.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
# =======================
|
||||
# amp APIs
|
||||
# =======================
|
||||
def autocast(
|
||||
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
||||
return nullcontext
|
|
@ -0,0 +1,282 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
__all__ = ["CudaAccelerator"]
|
||||
|
||||
|
||||
class CudaAccelerator(BaseAccelerator):
|
||||
"""
|
||||
Accelerator class for Nvidia CUDA devices.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False)
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def get_version(self) -> str:
|
||||
"""
|
||||
Return the version of the accelerator which torch is built against.
|
||||
"""
|
||||
return torch.version.cuda
|
||||
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""
|
||||
Return the current device.
|
||||
"""
|
||||
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
if device is None:
|
||||
if not dist.is_initialized():
|
||||
raise RuntimeError("Cannot get current device when distributed is not initialized.")
|
||||
device = dist.get_rank() % self.device_count()
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
return torch.cuda.get_device_name(device)
|
||||
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
return torch.cuda.is_available()
|
||||
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
return torch.cuda.device_count()
|
||||
|
||||
def get_device_capability(self, device=None) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the cuda capability of a device.
|
||||
"""
|
||||
return torch.cuda.get_device_capability(device)
|
||||
|
||||
def get_device_name(self, device=None) -> str:
|
||||
"""
|
||||
Gets the name of a device.
|
||||
"""
|
||||
return torch.cuda.get_device_name(device)
|
||||
|
||||
def get_device_properties(self, device):
|
||||
"""
|
||||
Gets the properties of a device.
|
||||
"""
|
||||
return torch.cuda.get_device_properties(device)
|
||||
|
||||
def utilization(self, device=None) -> int:
|
||||
"""
|
||||
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
|
||||
"""
|
||||
return torch.cuda.utilization(device)
|
||||
|
||||
# =======================
|
||||
# random number generator APIs
|
||||
# =======================
|
||||
def get_rng_state(self, device="cuda") -> torch.Tensor:
|
||||
"""
|
||||
Returns the random number generator state of the specified GPU as a ByteTensor.
|
||||
"""
|
||||
return torch.cuda.get_rng_state(device)
|
||||
|
||||
def get_rng_state_all(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
"""
|
||||
return torch.cuda.get_rng_state_all()
|
||||
|
||||
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
|
||||
"""
|
||||
Sets the random number generator state of the specified GPU.
|
||||
"""
|
||||
torch.cuda.set_rng_state(new_state, device)
|
||||
|
||||
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
|
||||
"""
|
||||
Sets the random number generator state of all devices.
|
||||
"""
|
||||
torch.cuda.set_rng_state_all(new_states)
|
||||
|
||||
def manual_seed(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers for the current GPU.
|
||||
"""
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
def manual_seed_all(self, seed: int) -> None:
|
||||
"""
|
||||
Set the random seed for the all processes.
|
||||
"""
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def seed(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number for the current GPU.
|
||||
"""
|
||||
torch.cuda.seed()
|
||||
|
||||
def seed_all(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number on all GPUs.
|
||||
"""
|
||||
torch.cuda.seed_all()
|
||||
|
||||
def initial_seed(self) -> int:
|
||||
"""
|
||||
Returns the current random seed of the current GPU.
|
||||
"""
|
||||
return torch.cuda.initial_seed()
|
||||
|
||||
# =======================
|
||||
# memory management APIs
|
||||
# =======================
|
||||
|
||||
def empty_cache(self) -> None:
|
||||
"""
|
||||
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def memory_stats(self, device=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of CUDA memory allocator statistics for a given device.
|
||||
"""
|
||||
return torch.cuda.memory_stats(device=device)
|
||||
|
||||
def memory_summary(self, device=None, abbreviated=False) -> str:
|
||||
"""
|
||||
Returns a human-readable printout of the current memory allocator statistics for a given device.
|
||||
"""
|
||||
return torch.cuda.memory_summary(device=device, abbreviated=abbreviated)
|
||||
|
||||
def memory_snapshot(self):
|
||||
"""
|
||||
Returns a snapshot of the CUDA memory allocator state across all devices.
|
||||
"""
|
||||
return torch.cuda.memory_snapshot()
|
||||
|
||||
def memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
return torch.cuda.memory_allocated(device=device)
|
||||
|
||||
def max_memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
return torch.cuda.max_memory_allocated(device=device)
|
||||
|
||||
def reset_max_memory_allocated(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
|
||||
"""
|
||||
torch.cuda.reset_max_memory_allocated(device=device)
|
||||
|
||||
def reset_max_memory_cached(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
|
||||
"""
|
||||
torch.cuda.reset_max_memory_cached(device=device)
|
||||
|
||||
def memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return torch.cuda.memory_reserved(device=device)
|
||||
|
||||
def max_memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return torch.cuda.max_memory_reserved(device=device)
|
||||
|
||||
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
|
||||
"""
|
||||
Set memory fraction for a process.
|
||||
"""
|
||||
torch.cuda.set_per_process_memory_fraction(fraction, device=device)
|
||||
|
||||
def reset_peak_memory_stats(self, device=None) -> None:
|
||||
"""
|
||||
Resets the "peak" stats tracked by the CUDA memory allocator.
|
||||
"""
|
||||
torch.cuda.reset_peak_memory_stats(device=device)
|
||||
|
||||
# =======================
|
||||
# streams and events APIs
|
||||
# =======================
|
||||
|
||||
def Stream(self, device=None, priority=0, **kwargs):
|
||||
"""
|
||||
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
|
||||
"""
|
||||
return torch.cuda.Stream(device, priority, **kwargs)
|
||||
|
||||
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
"""
|
||||
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
|
||||
"""
|
||||
return torch.cuda.Event(enable_timing, blocking, interprocess)
|
||||
|
||||
def current_stream(self, device=None):
|
||||
"""
|
||||
Returns the currently selected Stream for a given device.
|
||||
"""
|
||||
return torch.cuda.current_stream(device)
|
||||
|
||||
def default_stream(self, device=None):
|
||||
"""
|
||||
Returns the default Stream for a given device.
|
||||
"""
|
||||
return torch.cuda.default_stream(device)
|
||||
|
||||
def set_stream(self, stream_):
|
||||
"""
|
||||
Sets the current stream.This is a wrapper API to set the stream.
|
||||
"""
|
||||
torch.cuda.set_stream(stream_)
|
||||
|
||||
def stream(self, stream_):
|
||||
"""
|
||||
Wrapper around the Context-manager StreamContext that selects a given stream.
|
||||
"""
|
||||
return torch.cuda.stream(stream_)
|
||||
|
||||
# =======================
|
||||
# amp APIs
|
||||
# =======================
|
||||
def autocast(
|
||||
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
||||
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
|
|
@ -0,0 +1,288 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
try:
|
||||
import torch_npu # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["NpuAccelerator"]
|
||||
|
||||
|
||||
class NpuAccelerator(BaseAccelerator):
|
||||
"""
|
||||
Accelerator class for Huawei NPU devices.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name="npu", communication_backend="hccl", is_synchronous=False)
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def get_version(self) -> str:
|
||||
"""
|
||||
Return the version of the accelerator which torch is built against.
|
||||
"""
|
||||
return torch.version.cann
|
||||
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""
|
||||
Return the current device.
|
||||
"""
|
||||
return torch.device(f"npu:{torch.npu.current_device()}")
|
||||
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
return torch.npu.current_device()
|
||||
|
||||
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
if device is None:
|
||||
if not dist.is_initialized():
|
||||
raise RuntimeError("Cannot get current device when distributed is not initialized.")
|
||||
device = dist.get_rank() % self.device_count()
|
||||
torch.npu.set_device(device)
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
return torch.npu.get_device_name(device)
|
||||
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
torch.npu.synchronize(device)
|
||||
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
return torch.npu.is_available()
|
||||
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
return torch.npu.device_count()
|
||||
|
||||
def get_device_capability(self, device=None) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the npu capability of a device.
|
||||
"""
|
||||
return torch.npu.get_device_capability(device)
|
||||
|
||||
def get_device_name(self, device=None) -> str:
|
||||
"""
|
||||
Gets the name of a device.
|
||||
"""
|
||||
return torch.npu.get_device_name(device)
|
||||
|
||||
def get_device_properties(self, device):
|
||||
"""
|
||||
Gets the properties of a device.
|
||||
"""
|
||||
return torch.npu.get_device_properties(device)
|
||||
|
||||
def utilization(self, device=None) -> int:
|
||||
"""
|
||||
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
|
||||
"""
|
||||
return torch.npu.utilization(device)
|
||||
|
||||
# =======================
|
||||
# random number generator APIs
|
||||
# =======================
|
||||
def get_rng_state(self, device="npu") -> torch.Tensor:
|
||||
"""
|
||||
Returns the random number generator state of the specified GPU as a ByteTensor.
|
||||
"""
|
||||
return torch.npu.get_rng_state(device)
|
||||
|
||||
def get_rng_state_all(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
"""
|
||||
return torch.npu.get_rng_state_all()
|
||||
|
||||
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None:
|
||||
"""
|
||||
Sets the random number generator state of the specified GPU.
|
||||
"""
|
||||
torch.npu.set_rng_state(new_state, device)
|
||||
|
||||
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
|
||||
"""
|
||||
Sets the random number generator state of all devices.
|
||||
"""
|
||||
torch.npu.set_rng_state_all(new_states)
|
||||
|
||||
def manual_seed(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers for the current GPU.
|
||||
"""
|
||||
torch.npu.manual_seed(seed)
|
||||
|
||||
def manual_seed_all(self, seed: int) -> None:
|
||||
"""
|
||||
Set the random seed for the all processes.
|
||||
"""
|
||||
torch.npu.manual_seed_all(seed)
|
||||
|
||||
def seed(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number for the current GPU.
|
||||
"""
|
||||
torch.npu.seed()
|
||||
|
||||
def seed_all(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number on all GPUs.
|
||||
"""
|
||||
torch.npu.seed_all()
|
||||
|
||||
def initial_seed(self) -> int:
|
||||
"""
|
||||
Returns the current random seed of the current GPU.
|
||||
"""
|
||||
return torch.npu.initial_seed()
|
||||
|
||||
# =======================
|
||||
# memory management APIs
|
||||
# =======================
|
||||
|
||||
def empty_cache(self) -> None:
|
||||
"""
|
||||
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
|
||||
"""
|
||||
torch.npu.empty_cache()
|
||||
|
||||
def memory_stats(self, device=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of npu memory allocator statistics for a given device.
|
||||
"""
|
||||
return torch.npu.memory_stats(device=device)
|
||||
|
||||
def memory_summary(self, device=None, abbreviated=False) -> str:
|
||||
"""
|
||||
Returns a human-readable printout of the current memory allocator statistics for a given device.
|
||||
"""
|
||||
return torch.npu.memory_summary(device=device, abbreviated=abbreviated)
|
||||
|
||||
def memory_snapshot(self):
|
||||
"""
|
||||
Returns a snapshot of the npu memory allocator state across all devices.
|
||||
"""
|
||||
return torch.npu.memory_snapshot()
|
||||
|
||||
def memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
return torch.npu.memory_allocated(device=device)
|
||||
|
||||
def max_memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
return torch.npu.max_memory_allocated(device=device)
|
||||
|
||||
def reset_max_memory_allocated(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
|
||||
"""
|
||||
torch.npu.reset_max_memory_allocated(device=device)
|
||||
|
||||
def reset_max_memory_cached(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
|
||||
"""
|
||||
torch.npu.reset_max_memory_cached(device=device)
|
||||
|
||||
def memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return torch.npu.memory_reserved(device=device)
|
||||
|
||||
def max_memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return torch.npu.max_memory_reserved(device=device)
|
||||
|
||||
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
|
||||
"""
|
||||
Set memory fraction for a process.
|
||||
"""
|
||||
torch.npu.set_per_process_memory_fraction(fraction, device=device)
|
||||
|
||||
def reset_peak_memory_stats(self, device=None) -> None:
|
||||
"""
|
||||
Resets the "peak" stats tracked by the npu memory allocator.
|
||||
"""
|
||||
torch.npu.reset_peak_memory_stats(device=device)
|
||||
|
||||
# =======================
|
||||
# streams and events APIs
|
||||
# =======================
|
||||
|
||||
def Stream(self, device=None, priority=0, **kwargs):
|
||||
"""
|
||||
A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details.
|
||||
"""
|
||||
return torch.npu.Stream(device, priority, **kwargs)
|
||||
|
||||
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
"""
|
||||
npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams.
|
||||
"""
|
||||
return torch.npu.Event(enable_timing, blocking, interprocess)
|
||||
|
||||
def current_stream(self, device=None):
|
||||
"""
|
||||
Returns the currently selected Stream for a given device.
|
||||
"""
|
||||
return torch.npu.current_stream(device)
|
||||
|
||||
def default_stream(self, device=None):
|
||||
"""
|
||||
Returns the default Stream for a given device.
|
||||
"""
|
||||
return torch.npu.default_stream(device)
|
||||
|
||||
def set_stream(self, stream_):
|
||||
"""
|
||||
Sets the current stream.This is a wrapper API to set the stream.
|
||||
"""
|
||||
torch.npu.set_stream(stream_)
|
||||
|
||||
def stream(self, stream_):
|
||||
"""
|
||||
Wrapper around the Context-manager StreamContext that selects a given stream.
|
||||
"""
|
||||
return torch.npu.stream(stream_)
|
||||
|
||||
# =======================
|
||||
# amp APIs
|
||||
# =======================
|
||||
def autocast(
|
||||
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
||||
return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
|
|
@ -7,8 +7,8 @@ from typing import Dict
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
__all__ = ["BaseGradScaler"]
|
||||
|
||||
|
@ -23,7 +23,7 @@ class BaseGradScaler(ABC):
|
|||
|
||||
def __init__(self, initial_scale: float, verbose: bool):
|
||||
assert initial_scale > 0
|
||||
self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
|
||||
self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
|
||||
self._verbose = verbose
|
||||
|
||||
if self._verbose:
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Optional
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.utils.device import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .base_grad_scaler import BaseGradScaler
|
||||
|
||||
|
@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler):
|
|||
hysteresis: int = 2,
|
||||
verbose: bool = False,
|
||||
):
|
||||
a = get_accelerator()
|
||||
a.device_count()
|
||||
super().__init__(initial_scale, verbose)
|
||||
if min_scale:
|
||||
self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
|
||||
self._min_scale = torch.tensor(
|
||||
[min_scale], device=get_accelerator().get_current_device(), dtype=torch.float
|
||||
)
|
||||
else:
|
||||
self._min_scale = None
|
||||
|
||||
if max_scale:
|
||||
self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
|
||||
self._max_scale = torch.tensor(
|
||||
[max_scale], device=get_accelerator().get_current_device(), dtype=torch.float
|
||||
)
|
||||
else:
|
||||
self._max_scale = None
|
||||
|
||||
|
@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler):
|
|||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._scale = state_dict["scale"].to(get_current_device())
|
||||
self._scale = state_dict["scale"].to(get_accelerator().get_current_device())
|
||||
self._growth_factor = state_dict["growth_factor"]
|
||||
self._backoff_factor = state_dict["backoff_factor"]
|
||||
self._hysteresis = state_dict["hysteresis"]
|
||||
|
|
|
@ -5,8 +5,8 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base import MixedPrecisionMixin
|
||||
|
||||
|
@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
|
|||
max_scale=max_scale,
|
||||
)
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
|
||||
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
|
||||
|
||||
@property
|
||||
def loss_scale(self) -> float:
|
||||
|
|
|
@ -4,10 +4,10 @@ from typing import Dict, Tuple
|
|||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base_offload_module import BaseOffloadModule
|
||||
from .region import Region
|
||||
|
@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper):
|
|||
hysteresis=hysteresis,
|
||||
max_scale=max_scale,
|
||||
)
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
|
||||
self._found_overflow: torch.Tensor = torch.zeros(
|
||||
1, dtype=torch.int64, device=get_accelerator().get_current_device()
|
||||
)
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def _set_grad_ptr(self):
|
||||
|
|
|
@ -11,7 +11,7 @@ except:
|
|||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.utils.device import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .region import Region
|
||||
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
|
||||
|
@ -57,7 +57,10 @@ class Solver(ABC):
|
|||
if memory_budget > 0:
|
||||
self.memory_budget = memory_budget * self.error_factor
|
||||
else:
|
||||
self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
|
||||
self.memory_budget = (
|
||||
torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
|
||||
* self.error_factor
|
||||
)
|
||||
|
||||
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
|
||||
self.comp_power: float = self._extract_computing_power()
|
||||
|
|
|
@ -5,8 +5,8 @@ import torch.nn as nn
|
|||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils.device import autocast
|
||||
|
||||
from .mixed_precision_base import MixedPrecision
|
||||
|
||||
|
@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
|
|||
super().__init__(module)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
with autocast():
|
||||
with get_accelerator().autocast():
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
get_model_base_filenames,
|
||||
|
@ -27,8 +28,6 @@ from colossalai.checkpoint_io.utils import (
|
|||
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.device import IS_NPU_AVAILABLE
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats
|
||||
|
||||
|
@ -366,11 +365,11 @@ class GeminiPlugin(DPPluginBase):
|
|||
) -> None:
|
||||
super().__init__()
|
||||
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
|
||||
if IS_NPU_AVAILABLE:
|
||||
if get_accelerator().name == "npu":
|
||||
assert placement_policy == "static", "NPU only supports static placement policy"
|
||||
self.gemini_config = dict(
|
||||
chunk_config_dict=chunk_config_dict,
|
||||
chunk_init_device=(chunk_init_device or get_current_device()),
|
||||
chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
|
||||
placement_policy=placement_policy,
|
||||
enable_gradient_accumulation=enable_gradient_accumulation,
|
||||
shard_param_frac=shard_param_frac,
|
||||
|
@ -486,7 +485,10 @@ class GeminiPlugin(DPPluginBase):
|
|||
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
|
||||
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
|
||||
sampler = DistributedSampler(
|
||||
dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle
|
||||
dataset,
|
||||
num_replicas=zero_world_size * extra_dp_world_size,
|
||||
rank=zero_rank * extra_dp_world_size + extra_dp_rank,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
|
|
|
@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map
|
|||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
|
||||
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
|
@ -28,7 +29,6 @@ from colossalai.shardformer import ShardConfig, ShardFormer
|
|||
from colossalai.shardformer.layer.utils import SeqParallelUtils
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.tensor.d_tensor.api import is_distributed_tensor
|
||||
from colossalai.utils.device import get_current_device
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
|
||||
from .pp_plugin_base import PipelinePluginBase
|
||||
|
@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
self.mixed_precision = torch.bfloat16
|
||||
if self.mixed_precision is not None:
|
||||
module = module.to(self.mixed_precision)
|
||||
module = module.to(get_current_device())
|
||||
module = module.to(get_accelerator().get_current_device())
|
||||
|
||||
# setting input type cast when using mixed precision
|
||||
self.convert_fn = None
|
||||
|
@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
|
||||
if grads is not None:
|
||||
# Synchronize provided gradient tensors across the tensor parallelism group.
|
||||
|
@ -346,7 +345,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
|
||||
if norm_type == inf:
|
||||
total_norm = max(grad.data.abs().max() for grad in gradients)
|
||||
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
|
||||
total_norm_cuda = torch.tensor(
|
||||
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
|
||||
)
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||
if self.pp_size > 1:
|
||||
|
@ -386,7 +387,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
total_norm_exponentiated_cuda = torch.tensor(
|
||||
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
|
||||
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
|
||||
)
|
||||
if self.tp_size > 1:
|
||||
# compute norm in tp process group
|
||||
|
@ -487,7 +488,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward(loss, *args, **kwargs)
|
||||
|
||||
|
@ -513,7 +513,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward_by_grad(tensor, grad)
|
||||
|
||||
|
@ -545,7 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
# so we need to calculate the norm of 'tp' and 'pp' gradients.
|
||||
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
|
||||
|
||||
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
|
||||
total_norm_cuda = torch.tensor(
|
||||
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
|
||||
)
|
||||
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||
|
@ -589,7 +590,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
total_norm_exponentiated_cuda = torch.tensor(
|
||||
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
|
||||
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
|
||||
)
|
||||
if self.tp_size > 1:
|
||||
# compute norm in tp process group
|
||||
|
@ -674,7 +675,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass `_sync_grad` method to synchronize gradients.
|
||||
super()._sync_grad()
|
||||
|
||||
|
@ -802,7 +802,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
# so we only need to calculate the norm 'tp' of 'pp' gradients.
|
||||
total_norm = super()._compute_grad_norm(gradients, norm_type)
|
||||
|
||||
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
|
||||
total_norm_cuda = torch.tensor(
|
||||
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
|
||||
)
|
||||
|
||||
if tp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||
|
@ -842,7 +844,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
total_norm_exponentiated_cuda = torch.tensor(
|
||||
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
|
||||
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
|
||||
)
|
||||
if dp_size > 1:
|
||||
# compute norm in dp process group
|
||||
|
@ -1081,7 +1083,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return True
|
||||
|
||||
def support_no_sync(self) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
|
||||
return True
|
||||
|
@ -1175,9 +1177,14 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
||||
# run with gradients accumulation
|
||||
if model.require_grad_sync == False or (
|
||||
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
|
||||
):
|
||||
return outputs
|
||||
|
||||
# Synchronize the grads of shared parameters of the model.
|
||||
model.sync_shared_params()
|
||||
|
||||
# Synchronize sequence parallelism gradients of the model.
|
||||
model.sync_sp_grads()
|
||||
|
||||
|
@ -1241,5 +1248,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
|
||||
|
||||
def no_sync(self, model: Module) -> Iterator[None]:
|
||||
raise NotImplementedError
|
||||
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
assert (
|
||||
self.zero_stage != 2
|
||||
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
|
||||
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
from torch.utils._pytree import tree_map
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
get_optimizer_base_filenames,
|
||||
|
@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import (
|
|||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
|
@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|||
self.dtype = torch.bfloat16
|
||||
if self.dtype is not None:
|
||||
module = module.to(self.dtype)
|
||||
module = module.to(get_current_device())
|
||||
module = module.to(get_accelerator().get_current_device())
|
||||
self.module = module
|
||||
self.convert_fn = None
|
||||
if self.dtype is not None:
|
||||
|
|
|
@ -6,12 +6,12 @@ import warnings
|
|||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.context import Config
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
|
||||
from colossalai.utils import set_seed
|
||||
|
||||
|
||||
def launch(
|
||||
|
@ -47,17 +47,18 @@ def launch(
|
|||
if rank == 0:
|
||||
warnings.warn("`config` is deprecated and will be removed soon.")
|
||||
|
||||
if IS_NPU_AVAILABLE and backend == "nccl":
|
||||
backend = "hccl"
|
||||
cur_accelerator = get_accelerator()
|
||||
|
||||
backend = cur_accelerator.communication_backend
|
||||
|
||||
# init default process group
|
||||
init_method = f"tcp://[{host}]:{port}"
|
||||
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
|
||||
|
||||
# set cuda device
|
||||
if torch.cuda.is_available() or IS_NPU_AVAILABLE:
|
||||
# if local rank is not given, calculate automatically
|
||||
set_device(local_rank)
|
||||
if cur_accelerator.support_set_device:
|
||||
cur_accelerator.set_device(local_rank)
|
||||
|
||||
set_seed(seed)
|
||||
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
|
||||
|
||||
__all__ = [
|
||||
"LayerNorm",
|
||||
"FusedScaleMaskSoftmax",
|
||||
"MultiHeadAttention",
|
||||
]
|
|
@ -1,63 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "column_remap.cuh"
|
||||
#include "util.cuh"
|
||||
|
||||
const int SHUF_BLOCKSIZE_X = 256;
|
||||
const int SHUF_BLOCKSIZE_Y = 16;
|
||||
|
||||
__global__ void column_remap_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
half* __restrict__ x_new,
|
||||
const int x_width,
|
||||
const int x_height,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
|
||||
if (x_column >= x_width) return;
|
||||
//if (x_row >= x_height) return;
|
||||
|
||||
int x_stride = x_width;
|
||||
int x_idx = x_row * x_stride + x_column;
|
||||
|
||||
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
|
||||
int x_idx_end = x_row_end * x_stride + x_column;
|
||||
|
||||
int s_column = x_map[x_column];
|
||||
int s_idx = x_row * x_stride + s_column;
|
||||
|
||||
while (x_idx < x_idx_end)
|
||||
{
|
||||
x_new[x_idx] = x[s_idx];
|
||||
x_idx += x_stride;
|
||||
s_idx += x_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Remap columns in x to correspond to sequential group index before matmul
|
||||
//
|
||||
// perform x -> seq_x such that seq_x @ seq_w == x @ w
|
||||
|
||||
void column_remap_cuda
|
||||
(
|
||||
const half* x,
|
||||
half* x_new,
|
||||
const int x_height,
|
||||
const int x_width,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
|
||||
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
|
||||
1
|
||||
);
|
||||
|
||||
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _column_remap_cuh
|
||||
#define _column_remap_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
|
||||
void column_remap_cuda
|
||||
(
|
||||
const half* x,
|
||||
half* x_new,
|
||||
const int x_height,
|
||||
const int x_width,
|
||||
const uint32_t* x_map
|
||||
);
|
||||
|
||||
#endif
|
|
@ -1,58 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _cuda_compat_cuh
|
||||
#define _cuda_compat_cuh
|
||||
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif
|
|
@ -1,75 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#define _cuda_buffers_cu
|
||||
#include "cuda_buffers.cuh"
|
||||
|
||||
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
|
||||
// __constant__ half2 q4_table[16][256];
|
||||
// half2 q4_table_host[16][256];
|
||||
// bool q4_table_init = false;
|
||||
|
||||
CudaBuffers::CudaBuffers
|
||||
(
|
||||
int _device,
|
||||
int _temp_state_size,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
temp_state_size(_temp_state_size),
|
||||
temp_state(_temp_state),
|
||||
temp_dq(_temp_dq)
|
||||
{
|
||||
cudaSetDevice(_device);
|
||||
|
||||
cudaStreamCreate(&alt_stream_1);
|
||||
cudaStreamCreate(&alt_stream_2);
|
||||
cudaStreamCreate(&alt_stream_3);
|
||||
cudaEventCreate(&alt_stream_1_done);
|
||||
cudaEventCreate(&alt_stream_2_done);
|
||||
cudaEventCreate(&alt_stream_3_done);
|
||||
}
|
||||
|
||||
CudaBuffers::~CudaBuffers()
|
||||
{
|
||||
cudaStreamDestroy(alt_stream_1);
|
||||
cudaStreamDestroy(alt_stream_2);
|
||||
cudaStreamDestroy(alt_stream_3);
|
||||
cudaEventDestroy(alt_stream_1_done);
|
||||
cudaEventDestroy(alt_stream_2_done);
|
||||
cudaEventDestroy(alt_stream_3_done);
|
||||
}
|
||||
|
||||
CudaBuffers* get_buffers(const int device_index)
|
||||
{
|
||||
return g_buffers[device_index];
|
||||
}
|
||||
|
||||
void prepare_buffers_cuda
|
||||
(
|
||||
int _device,
|
||||
int _temp_state_size,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
)
|
||||
{
|
||||
CudaBuffers* buffers = new CudaBuffers
|
||||
(
|
||||
_device,
|
||||
_temp_state_size,
|
||||
_temp_state,
|
||||
_temp_dq
|
||||
);
|
||||
|
||||
g_buffers[_device] = buffers;
|
||||
}
|
||||
|
||||
void cleanup_buffers_cuda()
|
||||
{
|
||||
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
|
||||
{
|
||||
if (!g_buffers[i]) continue;
|
||||
delete g_buffers[i];
|
||||
g_buffers[i] = NULL;
|
||||
}
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _cuda_buffers_cuh
|
||||
#define _cuda_buffers_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
const int CUDA_MAX_DEVICES = 16;
|
||||
|
||||
// #ifndef _cuda_buffers_cu
|
||||
// extern __constant__ half2 q4_table[16][256];
|
||||
// #endif
|
||||
|
||||
class CudaBuffers
|
||||
{
|
||||
public:
|
||||
int device;
|
||||
|
||||
half* temp_state; // [max_hidden_rows * intermediate_size]
|
||||
int temp_state_size;
|
||||
half* temp_dq; // size of largest quant tensor * 8
|
||||
|
||||
cudaStream_t alt_stream_1;
|
||||
cudaStream_t alt_stream_2;
|
||||
cudaStream_t alt_stream_3;
|
||||
cudaEvent_t alt_stream_1_done;
|
||||
cudaEvent_t alt_stream_2_done;
|
||||
cudaEvent_t alt_stream_3_done;
|
||||
|
||||
CudaBuffers
|
||||
(
|
||||
int _device,
|
||||
int _temp_state_size,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
);
|
||||
~CudaBuffers();
|
||||
};
|
||||
|
||||
CudaBuffers* get_buffers(const int device_index);
|
||||
|
||||
void prepare_buffers_cuda
|
||||
(
|
||||
int _device,
|
||||
int _temp_state_size,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
);
|
||||
|
||||
void cleanup_buffers_cuda();
|
||||
|
||||
#endif
|
|
@ -1,49 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _hip_compat_cuh
|
||||
#define _hip_compat_cuh
|
||||
|
||||
// Workaround for a bug in hipamd, backported from upstream.
|
||||
__device__ __forceinline__ __half __compat_hrcp(__half x) {
|
||||
return __half_raw{
|
||||
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
|
||||
}
|
||||
|
||||
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
|
||||
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
|
||||
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
|
||||
}
|
||||
|
||||
#define hrcp __compat_hrcp
|
||||
#define h2rcp __compat_h2rcp
|
||||
|
||||
// Workaround for hipify_python using rocblas instead of hipblas.
|
||||
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
|
||||
hipblasOperation_t transA,
|
||||
hipblasOperation_t transB,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
const half* alpha,
|
||||
const half* AP,
|
||||
int lda,
|
||||
const half* BP,
|
||||
int ldb,
|
||||
const half* beta,
|
||||
half* CP,
|
||||
int ldc) {
|
||||
return hipblasHgemm(handle, transA, transB, m, n, k,
|
||||
reinterpret_cast<const hipblasHalf *>(alpha),
|
||||
reinterpret_cast<const hipblasHalf *>(AP), lda,
|
||||
reinterpret_cast<const hipblasHalf *>(BP), ldb,
|
||||
reinterpret_cast<const hipblasHalf *>(beta),
|
||||
reinterpret_cast<hipblasHalf *>(CP), ldc);
|
||||
}
|
||||
|
||||
#define rocblas_handle hipblasHandle_t
|
||||
#define rocblas_operation_none HIPBLAS_OP_N
|
||||
#define rocblas_get_stream hipblasGetStream
|
||||
#define rocblas_set_stream hipblasSetStream
|
||||
#define rocblas_hgemm __compat_hipblasHgemm
|
||||
|
||||
#endif
|
|
@ -1,254 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include "util.cuh"
|
||||
#include "tuning.h"
|
||||
#include "cuda_buffers.cuh"
|
||||
#include "q4_matrix.cuh"
|
||||
#include "q4_matmul.cuh"
|
||||
#include "column_remap.cuh"
|
||||
|
||||
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
|
||||
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
|
||||
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
|
||||
|
||||
void check_cuda(cudaError_t ret)
|
||||
{
|
||||
switch (ret)
|
||||
{
|
||||
case cudaSuccess:
|
||||
break;
|
||||
|
||||
case cudaUnspecified:
|
||||
printf(" **** Unspecified error\n");
|
||||
TORCH_CHECK(false, "CUDA error");
|
||||
break;
|
||||
|
||||
default:
|
||||
printf(" **** CUDA error\n"); \
|
||||
printf(" **** %s\n", cudaGetErrorString(ret)); \
|
||||
TORCH_CHECK(false, "CUDA error"); \
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define STRINGIFY_(__x) #__x
|
||||
#define STRINGIFY(__x) STRINGIFY_(__x)
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
|
||||
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
|
||||
|
||||
#define TORCH_CHECK_DEVICE_INDEX(__index) \
|
||||
do { \
|
||||
TORCH_CHECK(__index >= 0, "no device index"); \
|
||||
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
|
||||
} while(0)
|
||||
|
||||
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
|
||||
do { \
|
||||
TORCH_CHECK_DTYPE(__w, kInt); \
|
||||
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
|
||||
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
|
||||
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
|
||||
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
|
||||
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
|
||||
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
|
||||
} while(0)
|
||||
|
||||
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
|
||||
{
|
||||
int groupsize = w.size(0) * 8 / w_zeros.size(0);
|
||||
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
|
||||
return groupsize;
|
||||
}
|
||||
|
||||
|
||||
// Tuning parameters
|
||||
|
||||
ExLlamaTuning tuningParams;
|
||||
|
||||
void set_tuning_params
|
||||
(
|
||||
int matmul_recons_thd,
|
||||
bool matmul_fused_remap,
|
||||
bool matmul_no_half2
|
||||
)
|
||||
{
|
||||
tuningParams.matmul_recons_thd = matmul_recons_thd;
|
||||
tuningParams.matmul_fused_remap = matmul_fused_remap;
|
||||
tuningParams.matmul_no_half2 = matmul_no_half2;
|
||||
}
|
||||
|
||||
|
||||
// Release all unmanaged objects allocated by the extension
|
||||
|
||||
void cleanup()
|
||||
{
|
||||
cleanup_buffers_cuda();
|
||||
g_q4_free_matrices();
|
||||
}
|
||||
|
||||
|
||||
// Prepare buffers for forward pass
|
||||
|
||||
void prepare_buffers
|
||||
(
|
||||
torch::Device device,
|
||||
torch::Tensor temp_state,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
int device_index = device.index();
|
||||
TORCH_CHECK_DEVICE_INDEX(device_index);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
|
||||
prepare_buffers_cuda
|
||||
(
|
||||
device_index,
|
||||
// buffer size used for sanity checks
|
||||
temp_state.numel(),
|
||||
(half*) temp_state.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Create Q4Matrix, return handle
|
||||
|
||||
uintptr_t make_q4
|
||||
(
|
||||
torch::Tensor qweight,
|
||||
torch::Tensor qzeros,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor g_idx,
|
||||
int device
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(qweight, kInt);
|
||||
TORCH_CHECK_DTYPE(qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE(scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
|
||||
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
|
||||
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
|
||||
|
||||
int width = qweight.size(1);
|
||||
int height = qweight.size(0) * 8;
|
||||
int groups = qzeros.size(0);
|
||||
|
||||
Q4Matrix* m = new Q4Matrix
|
||||
(
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
|
||||
(uint32_t*) qweight.data_ptr(),
|
||||
(uint32_t*) qzeros.data_ptr(),
|
||||
(half*) scales.data_ptr(),
|
||||
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
|
||||
|
||||
device
|
||||
);
|
||||
|
||||
g_q4_keep_matrix(m);
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
|
||||
// Matmul half @ quant -> half
|
||||
|
||||
void q4_matmul
|
||||
(
|
||||
torch::Tensor x,
|
||||
uintptr_t w,
|
||||
torch::Tensor out
|
||||
)
|
||||
{
|
||||
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
|
||||
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(out, kHalf);
|
||||
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
|
||||
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
int x_height = x.size(0);
|
||||
|
||||
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
|
||||
{
|
||||
q4_matmul_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr()
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
q4_matmul_recons_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr(),
|
||||
at::cuda::getCurrentCUDABlasHandle()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Remap columns in half tensor
|
||||
|
||||
void column_remap
|
||||
(
|
||||
torch::Tensor x,
|
||||
torch::Tensor x_new,
|
||||
torch::Tensor x_map
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_new, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_map, kInt);
|
||||
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
|
||||
|
||||
int height = x.size(0);
|
||||
int width = x.size(1);
|
||||
|
||||
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
column_remap_cuda
|
||||
(
|
||||
(half*) x.data_ptr(),
|
||||
(half*) x_new.data_ptr(),
|
||||
height,
|
||||
width,
|
||||
(uint32_t*) x_map.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
|
||||
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
|
||||
m.def("cleanup", &cleanup, "cleanup");
|
||||
m.def("make_q4", &make_q4, "make_q4");
|
||||
m.def("q4_matmul", &q4_matmul, "q4_matmul");
|
||||
}
|
|
@ -1,294 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _matrix_cuh
|
||||
#define _matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||
};
|
||||
|
||||
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
|
||||
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
|
||||
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
|
||||
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
|
||||
|
||||
half2 tmp = __hmul2(*h_ptr++, v_01);
|
||||
tmp = __hfma2(*h_ptr++, v_23, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_45, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(*h_ptr++, v_0);
|
||||
tmp = __hfma(*h_ptr++, v_1, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_2, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_3, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_4, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_5, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_6, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8_x_map
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
half h_0 = h_ptr[*x_map_ptr++];
|
||||
half h_1 = h_ptr[*x_map_ptr++];
|
||||
half h_2 = h_ptr[*x_map_ptr++];
|
||||
half h_3 = h_ptr[*x_map_ptr++];
|
||||
half h_4 = h_ptr[*x_map_ptr++];
|
||||
half h_5 = h_ptr[*x_map_ptr++];
|
||||
half h_6 = h_ptr[*x_map_ptr++];
|
||||
half h_7 = h_ptr[*x_map_ptr++];
|
||||
|
||||
half2 h_01 = __halves2half2(h_0, h_1);
|
||||
half2 h_23 = __halves2half2(h_2, h_3);
|
||||
half2 h_45 = __halves2half2(h_4, h_5);
|
||||
half2 h_67 = __halves2half2(h_6, h_7);
|
||||
|
||||
half2 tmp = __hmul2(h_01, v_01);
|
||||
tmp = __hfma2(h_23, v_23, tmp);
|
||||
tmp = __hfma2(h_45, v_45, tmp);
|
||||
tmp = __hfma2(h_67, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_x_map_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif
|
|
@ -1,260 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "q4_matmul.cuh"
|
||||
#include "column_remap.cuh"
|
||||
#include "util.cuh"
|
||||
#include "matrix.cuh"
|
||||
#include "cu_compat.cuh"
|
||||
#include "cuda_buffers.cuh"
|
||||
#if defined(USE_ROCM)
|
||||
#include "hip_compat.cuh"
|
||||
#endif
|
||||
|
||||
const int THREADS_X = 32; // Block size and thread count along columns in w and out
|
||||
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
|
||||
|
||||
typedef void (*fp_q4_matmul_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
half*,
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint32_t*,
|
||||
bool
|
||||
);
|
||||
|
||||
template<bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
__global__ void q4_matmul_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out,
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int dim,
|
||||
const int width,
|
||||
const int groupsize,
|
||||
const int block_size_z,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int x_column = block_size_z * blockIdx.z;
|
||||
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
|
||||
|
||||
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
|
||||
|
||||
int iterations = (x_column_end - x_column) / 8;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_half x_(x, height, dim);
|
||||
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
|
||||
MatrixView_q4_column w_(w, dim, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
|
||||
// Zero output
|
||||
|
||||
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Loop over part of x row (and w column)
|
||||
|
||||
half2 acc = {};
|
||||
half acc_h = {};
|
||||
|
||||
if constexpr (use_groupsize)
|
||||
{
|
||||
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
|
||||
// could be slightly faster
|
||||
|
||||
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
else
|
||||
{
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
|
||||
|
||||
for (int k = x_column; k < x_column + iterations * 8; k += 8)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add to block result
|
||||
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half result = __hadd(__low2half(acc), __high2half(acc));
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), result);
|
||||
}
|
||||
else
|
||||
{
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
|
||||
}
|
||||
}
|
||||
|
||||
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
|
||||
{
|
||||
// <bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
if (tuningParams->matmul_no_half2) {
|
||||
if (block_size_z % groupsize == 0) {
|
||||
if (x_map) return q4_matmul_kernel<false, true, true >;
|
||||
else return q4_matmul_kernel<false, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<false, false, true >;
|
||||
else return q4_matmul_kernel<false, false, false>;
|
||||
}
|
||||
} else {
|
||||
if (block_size_z % groupsize == 0)
|
||||
{
|
||||
if (x_map) return q4_matmul_kernel<true, true, true >;
|
||||
else return q4_matmul_kernel<true, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<true, false, true >;
|
||||
else return q4_matmul_kernel<true, false, false>;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Compute y = x @ w
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero,
|
||||
cudaStream_t alt_stream
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
|
||||
uint32_t* x_map = w->cuda_x_map;
|
||||
const half* x_mapped = x;
|
||||
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
|
||||
{
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
x_map = NULL;
|
||||
}
|
||||
|
||||
int block_size_z;
|
||||
if (w->width == 4096) block_size_z = 384; // 7B
|
||||
else if (w->width == 11008) block_size_z = 256;
|
||||
else if (w->width == 5120) block_size_z = 384; // 13B
|
||||
else if (w->width == 13824) block_size_z = 256;
|
||||
else if (w->width == 6656) block_size_z = 256; // 33B
|
||||
else if (w->width == 17920) block_size_z = 128;
|
||||
else block_size_z = 256;
|
||||
|
||||
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
|
||||
|
||||
dim3 threads(THREADS_X, THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height + threads.y - 1) / threads.y,
|
||||
(dim + block_size_z - 1) / block_size_z
|
||||
);
|
||||
|
||||
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
|
||||
|
||||
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
|
||||
}
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
const cublasHandle_t handle,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
|
||||
const half* x_mapped = x;
|
||||
if (w->cuda_x_map)
|
||||
{
|
||||
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
}
|
||||
|
||||
w->reconstruct(buffers->temp_dq);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
|
||||
const float alpha = 1.0f;
|
||||
const float beta = no_zero ? 1.0f : 0.0f;
|
||||
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
|
||||
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
|
||||
#else
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
|
||||
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
|
||||
#endif
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _q4_matmul_cuh
|
||||
#define _q4_matmul_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "q4_matrix.cuh"
|
||||
#include "tuning.h"
|
||||
|
||||
// Workaround for hipify_python using rocblas instead of hipblas.
|
||||
#if defined(USE_ROCM)
|
||||
#include <hipblas/hipblas.h>
|
||||
#define rocblas_handle hipblasHandle_t
|
||||
#endif
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero = false,
|
||||
cudaStream_t alt_stream = NULL
|
||||
);
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
const cublasHandle_t handle,
|
||||
bool no_zero = false
|
||||
);
|
||||
|
||||
#endif
|
|
@ -1,225 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "q4_matrix.cuh"
|
||||
#include <vector>
|
||||
#include "util.cuh"
|
||||
#include "matrix.cuh"
|
||||
|
||||
using namespace std;
|
||||
|
||||
const int UNSHUF_BLOCKSIZE_X = 64;
|
||||
|
||||
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
|
||||
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
|
||||
|
||||
vector<Q4Matrix*> g_q4_matrices;
|
||||
|
||||
void g_q4_keep_matrix(Q4Matrix* m)
|
||||
{
|
||||
g_q4_matrices.push_back(m);
|
||||
}
|
||||
|
||||
void g_q4_free_matrices()
|
||||
{
|
||||
for (const auto& m : g_q4_matrices) delete m;
|
||||
g_q4_matrices.clear();
|
||||
}
|
||||
|
||||
Q4Matrix::Q4Matrix
|
||||
(
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _qweight,
|
||||
uint32_t* _qzeros,
|
||||
half* _scales,
|
||||
uint32_t* _g_idx,
|
||||
|
||||
const int _device
|
||||
) :
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
device(_device)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
cuda_qweight = _qweight;
|
||||
cuda_qzeros = _qzeros;
|
||||
cuda_scales = _scales;
|
||||
|
||||
groupsize = height / groups;
|
||||
|
||||
if (_g_idx) make_sequential(_g_idx);
|
||||
}
|
||||
|
||||
Q4Matrix::~Q4Matrix()
|
||||
{
|
||||
}
|
||||
|
||||
// Make sequential
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
|
||||
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
if (w2_column >= w2_stride) return;
|
||||
|
||||
int w_new2_row = blockIdx.y;
|
||||
|
||||
int x_map_idx = w_new2_row << 3;
|
||||
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = x_map[x_map_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
|
||||
dim3 blocks
|
||||
(
|
||||
(width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
|
||||
height / 8,
|
||||
1
|
||||
);
|
||||
|
||||
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
}
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out, // (y)
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int width,
|
||||
const int groupsize
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
|
||||
if (column >= width) return;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_q4_column w_(w, height, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
MatrixView_half w_scales_(w_scales, height / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
|
||||
|
||||
// Groupsize version
|
||||
|
||||
int group = row / groupsize;
|
||||
|
||||
half w_scale = w_scales_.item(group, column);
|
||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
||||
|
||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||
half* out_ptr = out_.item_ptr(row, column);
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < 32; s += 4)
|
||||
{
|
||||
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
|
||||
*out_ptr = w_item; out_ptr += out_.width;
|
||||
}
|
||||
}
|
||||
|
||||
void Q4Matrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height / 8 + threads.y - 1) / threads.y,
|
||||
1
|
||||
);
|
||||
|
||||
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _q4_matrix_cuh
|
||||
#define _q4_matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
|
||||
class Q4Matrix
|
||||
{
|
||||
public:
|
||||
|
||||
int device;
|
||||
|
||||
int height;
|
||||
int width;
|
||||
int groups;
|
||||
int groupsize;
|
||||
|
||||
uint32_t* cuda_qweight = NULL;
|
||||
uint32_t* cuda_qzeros = NULL;
|
||||
half* cuda_scales = NULL;
|
||||
uint32_t* cuda_x_map = NULL;
|
||||
|
||||
Q4Matrix
|
||||
(
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _qweight,
|
||||
uint32_t* _qzeros,
|
||||
half* _scales,
|
||||
uint32_t* _g_idx,
|
||||
|
||||
const int _device
|
||||
);
|
||||
|
||||
~Q4Matrix();
|
||||
|
||||
void reconstruct(half* out);
|
||||
|
||||
private:
|
||||
|
||||
void make_sequential(const uint32_t* cpu_g_idx);
|
||||
|
||||
};
|
||||
|
||||
void g_q4_keep_matrix(Q4Matrix* m);
|
||||
void g_q4_free_matrices();
|
||||
|
||||
#endif
|
|
@ -1,12 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _tuning_h
|
||||
#define _tuning_h
|
||||
|
||||
struct ExLlamaTuning {
|
||||
int matmul_recons_thd;
|
||||
bool matmul_fused_remap;
|
||||
bool matmul_no_half2;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -1,33 +0,0 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _util_cuh
|
||||
#define _util_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define cudaUnspecified hipErrorUnknown
|
||||
#else
|
||||
#define cudaUnspecified cudaErrorApiFailureBase
|
||||
#endif
|
||||
|
||||
// React to failure on return code != cudaSuccess
|
||||
|
||||
#define _cuda_check(fn) \
|
||||
do { \
|
||||
{_cuda_err = fn;} \
|
||||
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
|
||||
} while(false)
|
||||
|
||||
// React to failure on return code == 0
|
||||
|
||||
#define _alloc_check(fn) \
|
||||
do { \
|
||||
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
|
||||
else _cuda_err = cudaSuccess; \
|
||||
} while(false)
|
||||
|
||||
#endif
|
|
@ -1,191 +0,0 @@
|
|||
#include "block_reduce.h"
|
||||
#include "cuda_util.h"
|
||||
#include "kernels.h"
|
||||
#include "ls_cub.cuh"
|
||||
|
||||
ls::cub::CachingDeviceAllocator g_allocator(true);
|
||||
|
||||
template <typename T>
|
||||
__global__ void ls_cross_entropy_fw_kernel(
|
||||
const T *__restrict__ inputs, const int *__restrict__ targets,
|
||||
float *__restrict__ outputs, float *__restrict__ nll_loss_outputs,
|
||||
const int padding_idx, const float epsilon, const int vocab_size) {
|
||||
/* step1: compute each thread's max_logit and sum_exp_logit, store in
|
||||
* max_input, sum_exp_logit */
|
||||
const int block_start = blockIdx.x * vocab_size;
|
||||
const int left_idx = block_start + threadIdx.x;
|
||||
const int right_idx = (blockIdx.x + 1) * vocab_size;
|
||||
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
|
||||
float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
|
||||
int target_tid = targets[blockIdx.x];
|
||||
|
||||
if (target_tid == padding_idx) {
|
||||
if (threadIdx.x == 0) {
|
||||
nll_loss_outputs[blockIdx.x] = 0.f;
|
||||
outputs[blockIdx.x] = 0.f;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
|
||||
}
|
||||
blockReduce<ReduceType::kMax, 1>(max_input);
|
||||
__shared__ float s_max_input;
|
||||
if (threadIdx.x == 0) {
|
||||
s_max_input = max_input[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
float logit = static_cast<float>(inputs[i]) - s_max_input;
|
||||
sum_logits[0] += logit;
|
||||
sum_logits[1] += expf(logit);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 2>(sum_logits);
|
||||
__shared__ float s_sum_logit;
|
||||
__shared__ float s_sum_exp;
|
||||
if (threadIdx.x == 0) {
|
||||
s_sum_logit = sum_logits[0];
|
||||
s_sum_exp = sum_logits[1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float eps_i = epsilon / (vocab_size - 1);
|
||||
if (threadIdx.x == 0) {
|
||||
// neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max)
|
||||
float nll_loss = logf(s_sum_exp) -
|
||||
static_cast<float>(inputs[block_start + target_tid]) +
|
||||
s_max_input;
|
||||
nll_loss_outputs[blockIdx.x] = nll_loss;
|
||||
float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit;
|
||||
outputs[blockIdx.x] =
|
||||
(1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ls_cross_entropy_bw_kernel(
|
||||
const float *__restrict__ grad_outputs, const T *__restrict__ inputs,
|
||||
const int *__restrict__ targets, T *__restrict__ grad_inputs,
|
||||
const int padding_idx, const float epsilon, const int vocab_size) {
|
||||
/* step1: compute each thread's max_logit and sum_exp_logit, store in
|
||||
* max_input, sum_exp_logit */
|
||||
const int block_start = blockIdx.x * vocab_size;
|
||||
const int left_idx = block_start + threadIdx.x;
|
||||
const int right_idx = (blockIdx.x + 1) * vocab_size;
|
||||
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
|
||||
float sum_logits[1] = {0.f};
|
||||
const float grad_out = static_cast<float>(grad_outputs[0]);
|
||||
int target_tid = targets[blockIdx.x];
|
||||
|
||||
if (target_tid == padding_idx) {
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
grad_inputs[i] = 0.f;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
|
||||
}
|
||||
blockReduce<ReduceType::kMax, 1>(max_input);
|
||||
__shared__ float s_max_input;
|
||||
if (threadIdx.x == 0) {
|
||||
s_max_input = max_input[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
float logit = static_cast<float>(inputs[i]) - s_max_input;
|
||||
sum_logits[0] += expf(logit);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 1>(sum_logits);
|
||||
__shared__ float s_sum_exp;
|
||||
if (threadIdx.x == 0) {
|
||||
s_sum_exp = sum_logits[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float eps_i = epsilon / (vocab_size - 1);
|
||||
float nll_weight = 1.0 - epsilon - eps_i;
|
||||
|
||||
for (int i = left_idx; i < right_idx; i += blockDim.x) {
|
||||
float prob = expf(static_cast<float>(inputs[i]) - s_max_input) / s_sum_exp;
|
||||
float grad = 0;
|
||||
grad += (vocab_size * prob - 1) * eps_i;
|
||||
grad += prob * nll_weight;
|
||||
if ((i - block_start) == target_tid) {
|
||||
grad -= nll_weight;
|
||||
}
|
||||
grad_inputs[i] = grad_out * grad;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
|
||||
float *outputs_ptr, float *nll_loss_ptr,
|
||||
float *loss_buffer, const int padding_idx,
|
||||
const float epsilon, const int batch_size,
|
||||
const int seq_len, const int vocab_size,
|
||||
cudaStream_t stream) {
|
||||
int grid_dim = batch_size * seq_len;
|
||||
float *nll_loss_buffer = loss_buffer + grid_dim;
|
||||
ls_cross_entropy_fw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
|
||||
inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
|
||||
epsilon, vocab_size);
|
||||
|
||||
int num_items = grid_dim;
|
||||
void *d_temp_storage = NULL;
|
||||
size_t temp_storage_bytes = 0;
|
||||
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
|
||||
loss_buffer, outputs_ptr,
|
||||
num_items, stream));
|
||||
CHECK_GPU_ERROR(
|
||||
g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
|
||||
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
|
||||
loss_buffer, outputs_ptr,
|
||||
num_items, stream));
|
||||
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
|
||||
nll_loss_buffer, nll_loss_ptr,
|
||||
num_items, stream));
|
||||
CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
|
||||
}
|
||||
|
||||
template void launch_cross_entropy_fw<float>(
|
||||
const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
|
||||
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
|
||||
const float epsilon, const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
||||
|
||||
template void launch_cross_entropy_fw<__half>(
|
||||
const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
|
||||
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
|
||||
const float epsilon, const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
|
||||
const int *targets_ptr, T *grad_inputs_ptr,
|
||||
const int padding_idx, const float epsilon,
|
||||
const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream) {
|
||||
int grid_dim = batch_size * seq_len;
|
||||
ls_cross_entropy_bw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
|
||||
grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
|
||||
epsilon, vocab_size);
|
||||
}
|
||||
|
||||
template void launch_cross_entropy_bw<float>(
|
||||
const float *grad_outputs_ptr, const float *inputs_ptr,
|
||||
const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
|
||||
const float epsilon, const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
||||
|
||||
template void launch_cross_entropy_bw<__half>(
|
||||
const float *grad_outputs_ptr, const __half *inputs_ptr,
|
||||
const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
|
||||
const float epsilon, const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
|
@ -1,88 +0,0 @@
|
|||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Microsoft DeepSpeed
|
||||
This file is adapted from Microsoft DeepSpeed
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#include "cublas_wrappers.h"
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
|
||||
cublasOperation_t transb, int m, int n, int k,
|
||||
const float *alpha, const float *beta, const float *A,
|
||||
const float *B, float *C, cublasGemmAlgo_t algo) {
|
||||
cublasStatus_t status =
|
||||
cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha,
|
||||
(const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k,
|
||||
(const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n,
|
||||
(const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
|
||||
m, n, k, (int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
|
||||
cublasOperation_t transb, int m, int n, int k,
|
||||
const float *alpha, const float *beta, const __half *A,
|
||||
const __half *B, __half *C, cublasGemmAlgo_t algo) {
|
||||
cublasStatus_t status = cublasGemmEx(
|
||||
handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
|
||||
CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F,
|
||||
(transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C,
|
||||
CUDA_R_16F, m, CUDA_R_32F, algo);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
|
||||
m, n, k, (int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
|
||||
const float *alpha, const float *beta,
|
||||
const float *A, const float *B, float *C,
|
||||
cublasOperation_t op_A, cublasOperation_t op_B,
|
||||
int stride_A, int stride_B, int stride_C,
|
||||
int batch, cublasGemmAlgo_t algo) {
|
||||
cublasStatus_t status = cublasGemmStridedBatchedEx(
|
||||
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F,
|
||||
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F,
|
||||
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C,
|
||||
batch, CUDA_R_32F, algo);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
|
||||
"error: %d) \n",
|
||||
batch, m, n, k, (int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
|
||||
const float *alpha, const float *beta,
|
||||
const __half *A, const __half *B, __half *C,
|
||||
cublasOperation_t op_A, cublasOperation_t op_B,
|
||||
int stride_A, int stride_B, int stride_C,
|
||||
int batch, cublasGemmAlgo_t algo) {
|
||||
cublasStatus_t status = cublasGemmStridedBatchedEx(
|
||||
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F,
|
||||
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F,
|
||||
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C,
|
||||
batch, CUDA_R_32F, algo);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
fprintf(stderr,
|
||||
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
|
||||
m, n, k, (int)status);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -1,169 +0,0 @@
|
|||
#include <thrust/device_vector.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/transform_reduce.h>
|
||||
#include "cuda_util.h"
|
||||
|
||||
/* GPU function guard */
|
||||
std::string _cudaGetErrorString(cudaError_t error) {
|
||||
return cudaGetErrorString(error);
|
||||
}
|
||||
|
||||
std::string _cudaGetErrorString(cublasStatus_t error) {
|
||||
switch (error) {
|
||||
case CUBLAS_STATUS_SUCCESS:
|
||||
return "CUBLAS_STATUS_SUCCESS";
|
||||
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED:
|
||||
return "CUBLAS_STATUS_NOT_INITIALIZED";
|
||||
|
||||
case CUBLAS_STATUS_ALLOC_FAILED:
|
||||
return "CUBLAS_STATUS_ALLOC_FAILED";
|
||||
|
||||
case CUBLAS_STATUS_INVALID_VALUE:
|
||||
return "CUBLAS_STATUS_INVALID_VALUE";
|
||||
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH:
|
||||
return "CUBLAS_STATUS_ARCH_MISMATCH";
|
||||
|
||||
case CUBLAS_STATUS_MAPPING_ERROR:
|
||||
return "CUBLAS_STATUS_MAPPING_ERROR";
|
||||
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED:
|
||||
return "CUBLAS_STATUS_EXECUTION_FAILED";
|
||||
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR:
|
||||
return "CUBLAS_STATUS_INTERNAL_ERROR";
|
||||
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return "CUBLAS_STATUS_NOT_SUPPORTED";
|
||||
|
||||
case CUBLAS_STATUS_LICENSE_ERROR:
|
||||
return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||
}
|
||||
return "CUBLAS_UNKNOW";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void check_gpu_error(T result, char const *const func, const char *const file,
|
||||
int const line) {
|
||||
if (result) {
|
||||
throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" +
|
||||
std::to_string(line) +
|
||||
"): " + (_cudaGetErrorString(result)) + "\n");
|
||||
}
|
||||
}
|
||||
|
||||
template void check_gpu_error<cudaError_t>(cudaError_t result,
|
||||
char const *const func,
|
||||
const char *const file,
|
||||
int const line);
|
||||
template void check_gpu_error<cublasStatus_t>(cublasStatus_t result,
|
||||
char const *const func,
|
||||
const char *const file,
|
||||
int const line);
|
||||
|
||||
template <typename T>
|
||||
void print_vec(const T *outv, std::string outn, int num_output_ele) {
|
||||
std::cout << outn << ": ";
|
||||
std::vector<T> hout(num_output_ele, (T)0);
|
||||
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T),
|
||||
cudaMemcpyDeviceToHost);
|
||||
for (int i = 0; i < num_output_ele; i++) {
|
||||
std::cout << hout[i] << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <>
|
||||
void print_vec<__half>(const __half *outv, std::string outn,
|
||||
int num_output_ele) {
|
||||
std::cout << outn << ": ";
|
||||
std::vector<__half> hout(num_output_ele, (__half)0.f);
|
||||
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half),
|
||||
cudaMemcpyDeviceToHost);
|
||||
for (int i = 0; i < num_output_ele; i++) {
|
||||
std::cout << __half2float(hout[i]) << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template void print_vec<float>(const float *outv, std::string outn,
|
||||
int num_output_ele);
|
||||
|
||||
template void print_vec<int>(const int *outv, std::string outn,
|
||||
int num_output_ele);
|
||||
|
||||
template void print_vec<__half>(const __half *outv, std::string outn,
|
||||
int num_output_ele);
|
||||
|
||||
template <typename T>
|
||||
T *cuda_malloc(size_t ele_num) {
|
||||
size_t byte_size = ele_num * sizeof(T);
|
||||
T *pdata = nullptr;
|
||||
CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size));
|
||||
return pdata;
|
||||
}
|
||||
|
||||
template float *cuda_malloc<float>(size_t ele_num);
|
||||
|
||||
template __half *cuda_malloc<__half>(size_t ele_num);
|
||||
|
||||
template uint8_t *cuda_malloc<uint8_t>(size_t ele_num);
|
||||
|
||||
void cuda_free(void *pdata) {
|
||||
if (pdata != nullptr) {
|
||||
cudaFree(pdata);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct _isnan {
|
||||
__device__ bool operator()(T a) const { return isnan(a); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct _isnan<__half> {
|
||||
__device__ bool operator()(const __half a) const { return __hisnan(a); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct _isinf {
|
||||
__device__ bool operator()(T a) const { return isinf(a); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct _isinf<__half> {
|
||||
__device__ bool operator()(const __half a) const { return __hisinf(a); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
|
||||
std::string file, int line, cudaStream_t stream) {
|
||||
// check_nan_inf = 0 for checking nan
|
||||
// check_nan_inf = 1 for checking inf
|
||||
bool res = false;
|
||||
std::string msg = file + "(" + std::to_string(line) + "): ";
|
||||
if (check_nan_inf) {
|
||||
msg += "nan.";
|
||||
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
|
||||
data_ptr + dsize, _isnan<T>(), false,
|
||||
thrust::logical_or<bool>());
|
||||
} else {
|
||||
msg += "inf.";
|
||||
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
|
||||
data_ptr + dsize, _isinf<T>(), false,
|
||||
thrust::logical_or<bool>());
|
||||
}
|
||||
if (res) {
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
std::cout << msg << " [check pass]." << std::endl;
|
||||
}
|
||||
|
||||
template void check_nan_inf<float>(const float *data_ptr, int dsize,
|
||||
bool check_nan_inf, std::string file,
|
||||
int line, cudaStream_t stream);
|
||||
|
||||
template void check_nan_inf<__half>(const __half *data_ptr, int dsize,
|
||||
bool check_nan_inf, std::string file,
|
||||
int line, cudaStream_t stream);
|
File diff suppressed because it is too large
Load Diff
|
@ -1,232 +0,0 @@
|
|||
#include <cooperative_groups.h>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
/**
|
||||
@brief: fuse_transpose_bias
|
||||
Calculate the sum of elements in each column of the matrix.
|
||||
|
||||
@thread
|
||||
gridDim.x = ceil(cols / WARP_SIZE)
|
||||
blockDim.x = WARP_SIZE
|
||||
blockDim.y = WARP_SIZE
|
||||
|
||||
@param
|
||||
inp: [rows, cols]
|
||||
out: [cols]
|
||||
rows: the number of rows in the matrix
|
||||
cols: the number of cols in the matrix
|
||||
*/
|
||||
template <typename T>
|
||||
__global__ void column_sum_reduce(const T *__restrict__ inp,
|
||||
T *__restrict__ out, int rows, int cols) {
|
||||
__shared__ float tile[WARP_SIZE][WARP_SIZE];
|
||||
|
||||
cg::thread_block b = cg::this_thread_block();
|
||||
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
|
||||
|
||||
int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
|
||||
int y_stride = cols * WARP_SIZE;
|
||||
float localSum = 0;
|
||||
|
||||
// Loop across matrix row
|
||||
// TODO: optimize to log complexity
|
||||
if (idx < cols) {
|
||||
int offset = flat_2dim(threadIdx.y, idx, cols);
|
||||
for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
|
||||
localSum += (float)inp[offset];
|
||||
offset += y_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// The sum of a row in tile is equal to the sum of a col in original matrix
|
||||
tile[threadIdx.x][threadIdx.y] = localSum;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Sum the shared buffer.
|
||||
// The change of threadIdx.x is continuous
|
||||
float sum = tile[threadIdx.y][threadIdx.x];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate the sum of a row in tile
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
|
||||
if (pos < cols) out[pos] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
// [r, c] -> [c]
|
||||
template <>
|
||||
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
|
||||
int rows, int cols,
|
||||
cudaStream_t stream) {
|
||||
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
|
||||
dim3 block_dim(WARP_SIZE, WARP_SIZE);
|
||||
|
||||
column_sum_reduce<float>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
|
||||
int rows, int cols,
|
||||
cudaStream_t stream) {
|
||||
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
|
||||
dim3 block_dim(WARP_SIZE, WARP_SIZE);
|
||||
|
||||
column_sum_reduce<__half>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
|
||||
}
|
||||
|
||||
/**
|
||||
@brief: fused_add2
|
||||
Add two matrix inp1 and inp2 to out.
|
||||
|
||||
@thread
|
||||
gridDim.x = batch_size * seq_len
|
||||
blockDim.x = min(hidden_dim, MAX_THREADS)
|
||||
|
||||
@param
|
||||
inp1: [batch_size, seq_len, hidden_dim]
|
||||
inp2: [batch_size, seq_len, hidden_dim]
|
||||
out: [batch_size, seq_len, hidden_dim]
|
||||
batch_size: the size of the current batch
|
||||
seq_len: the sequence length of the current batch
|
||||
hidden_dim: dim of the hidden tensor
|
||||
*/
|
||||
template <typename T>
|
||||
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
|
||||
int hidden_dim);
|
||||
|
||||
template <>
|
||||
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
|
||||
const float *inp2, int hidden_dim) {
|
||||
int row_id = blockIdx.x;
|
||||
int offset = flat_2dim(row_id, 0, hidden_dim);
|
||||
|
||||
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
|
||||
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
|
||||
float4 *out_4 = reinterpret_cast<float4 *>(out);
|
||||
float4 vinp1;
|
||||
float4 vinp2;
|
||||
float4 val;
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
|
||||
vinp1 = inp1_4[offset + i];
|
||||
vinp2 = inp2_4[offset + i];
|
||||
val.x = vinp1.x + vinp2.x;
|
||||
val.y = vinp1.y + vinp2.y;
|
||||
val.z = vinp1.z + vinp2.z;
|
||||
val.w = vinp1.w + vinp2.w;
|
||||
out_4[offset + i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
|
||||
const __half *inp2, int hidden_dim) {
|
||||
int row_id = blockIdx.x;
|
||||
int offset = flat_2dim(row_id, 0, hidden_dim);
|
||||
|
||||
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
|
||||
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
|
||||
float4 *out_4 = reinterpret_cast<float4 *>(out);
|
||||
float4 vinp1;
|
||||
float4 vinp2;
|
||||
float4 val;
|
||||
__half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1);
|
||||
__half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2);
|
||||
__half2 *h2_val = reinterpret_cast<__half2 *>(&val);
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
|
||||
vinp1 = inp1_4[offset + i];
|
||||
vinp2 = inp2_4[offset + i];
|
||||
h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]);
|
||||
h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]);
|
||||
h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]);
|
||||
h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]);
|
||||
out_4[offset + i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
//[b, s, h] -> [b, s, h]
|
||||
template <>
|
||||
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
|
||||
int batch_size, int seq_len, int hidden_dim,
|
||||
cudaStream_t &stream) {
|
||||
hidden_dim >>= 2;
|
||||
|
||||
dim3 grid_dim(batch_size * seq_len);
|
||||
dim3 block_dim(min(hidden_dim, MAX_THREADS));
|
||||
|
||||
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
|
||||
hidden_dim);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_fused_add2<__half>(__half *out, const __half *inp1,
|
||||
const __half *inp2, int batch_size, int seq_len,
|
||||
int hidden_dim, cudaStream_t &stream) {
|
||||
hidden_dim >>= 3;
|
||||
|
||||
dim3 grid_dim(batch_size * seq_len);
|
||||
dim3 block_dim(min(hidden_dim, MAX_THREADS));
|
||||
|
||||
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
|
||||
hidden_dim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
|
||||
int sz0, int sz2, int sz1_1, int sz1_2) {
|
||||
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
|
||||
int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
|
||||
if (idx >= nele) {
|
||||
return;
|
||||
}
|
||||
float4 *dst_ptr = (float4 *)output + idx;
|
||||
int idx2 = idx % sz2;
|
||||
idx = idx / sz2;
|
||||
int idx1 = idx % (sz1_1 + sz1_2);
|
||||
int idx0 = idx / (sz1_1 + sz1_2);
|
||||
float4 *src_ptr = nullptr;
|
||||
int sz1 = 0;
|
||||
if (idx1 < sz1_1) {
|
||||
sz1 = sz1_1;
|
||||
src_ptr = (float4 *)inp1;
|
||||
} else {
|
||||
idx1 -= sz1_1;
|
||||
sz1 = sz1_2;
|
||||
src_ptr = (float4 *)inp2;
|
||||
}
|
||||
src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2);
|
||||
dst_ptr[0] = src_ptr[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
|
||||
float *output, int sz0, int sz2, int sz1_1,
|
||||
int sz1_2, cudaStream_t stream) {
|
||||
sz2 >>= 2;
|
||||
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
|
||||
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
|
||||
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
|
||||
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
|
||||
__half *output, int sz0, int sz2, int sz1_1,
|
||||
int sz1_2, cudaStream_t stream) {
|
||||
sz2 >>= 3;
|
||||
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
|
||||
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
|
||||
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
|
||||
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "cuda_util.h"
|
||||
|
||||
class Context {
|
||||
public:
|
||||
Context() : _stream(nullptr) {
|
||||
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
|
||||
}
|
||||
|
||||
virtual ~Context() {}
|
||||
|
||||
static Context &Instance() {
|
||||
static Context _ctx;
|
||||
return _ctx;
|
||||
}
|
||||
|
||||
void set_stream(cudaStream_t stream) {
|
||||
_stream = stream;
|
||||
CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream));
|
||||
}
|
||||
|
||||
cudaStream_t get_stream() { return _stream; }
|
||||
|
||||
cublasHandle_t get_cublashandle() { return _cublasHandle; }
|
||||
|
||||
private:
|
||||
cudaStream_t _stream;
|
||||
cublasHandle_t _cublasHandle;
|
||||
};
|
|
@ -1,46 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cuda_util.h"
|
||||
|
||||
template <typename T>
|
||||
class CrossEntropyLayer {
|
||||
public:
|
||||
CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
|
||||
|
||||
virtual ~CrossEntropyLayer();
|
||||
|
||||
void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
|
||||
float *nll_loss_ptr);
|
||||
|
||||
void Backward(const float *grad_outputs_ptr, const T *inputs_ptr,
|
||||
const int *targets_ptr, T *grad_inputs_ptr);
|
||||
|
||||
void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
|
||||
|
||||
private:
|
||||
void allocate_mem_buffer() {
|
||||
// allocate local gpu memory
|
||||
_loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
|
||||
}
|
||||
|
||||
void free_mem_buffer() {
|
||||
// free local gpu memory
|
||||
cuda_free(_loss_buffer);
|
||||
}
|
||||
|
||||
const int _padding_idx;
|
||||
const float _epsilon;
|
||||
const int _max_batch_tokens;
|
||||
|
||||
size_t _batch_size;
|
||||
size_t _seq_len;
|
||||
size_t _vocab_size;
|
||||
|
||||
float *_loss_buffer;
|
||||
};
|
|
@ -1,41 +0,0 @@
|
|||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Microsoft DeepSpeed
|
||||
This file is adapted from Microsoft DeepSpeed
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <mma.h>
|
||||
#include <stdio.h>
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
|
||||
cublasOperation_t transb, int m, int n, int k,
|
||||
const float *alpha, const float *beta, const float *A,
|
||||
const float *B, float *C,
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
|
||||
cublasOperation_t transb, int m, int n, int k,
|
||||
const float *alpha, const float *beta, const __half *A,
|
||||
const __half *B, __half *C,
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
||||
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
|
||||
const float *alpha, const float *beta,
|
||||
const float *A, const float *B, float *C,
|
||||
cublasOperation_t op_A, cublasOperation_t op_B,
|
||||
int stride_A, int stride_B, int stride_C,
|
||||
int batch,
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
|
||||
|
||||
int cublas_strided_batched_gemm(
|
||||
cublasHandle_t handle, int m, int n, int k, const float *alpha,
|
||||
const float *beta, const __half *A, const __half *B, __half *C,
|
||||
cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B,
|
||||
int stride_C, int batch,
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
|
@ -1,34 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#include <math_constants.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
template <typename T>
|
||||
void check_gpu_error(T result, char const *const func, const char *const file,
|
||||
int const line);
|
||||
|
||||
#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__)
|
||||
|
||||
template <typename T>
|
||||
void print_vec(const T *outv, std::string outn, int num_output_ele);
|
||||
|
||||
template <typename T>
|
||||
T *cuda_malloc(size_t ele_num);
|
||||
|
||||
void cuda_free(void *pdata);
|
||||
|
||||
template <typename T>
|
||||
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
|
||||
std::string file, int line, cudaStream_t stream);
|
||||
|
||||
#define CHECK_NAN_INF(ptr, size, stream) \
|
||||
check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
|
||||
check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))
|
|
@ -1,96 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
class Dropout {
|
||||
public:
|
||||
struct Config {
|
||||
float ratio;
|
||||
bool training;
|
||||
|
||||
Config(float r) : ratio(r), training(true) {}
|
||||
float RATIO() const { return training ? ratio : 0.0; }
|
||||
};
|
||||
|
||||
Dropout(const Config &config, size_t max_ele_num)
|
||||
: _config(config), _mask(nullptr) {
|
||||
_mask = cuda_malloc<uint8_t>(max_ele_num);
|
||||
}
|
||||
|
||||
virtual ~Dropout() { cuda_free(_mask); }
|
||||
|
||||
// after attention softmax
|
||||
void dropout(T *output, const T *input, int count, cudaStream_t stream,
|
||||
bool bwd = false) {
|
||||
launch_ls_dropout<T>(output, input, _mask, count, _config.RATIO(), stream,
|
||||
bwd);
|
||||
}
|
||||
|
||||
void d_dropout(T *d_inp_out, int count, cudaStream_t stream) {
|
||||
launch_ls_dropout<T>(d_inp_out, d_inp_out, _mask, count, _config.RATIO(),
|
||||
stream, true);
|
||||
}
|
||||
|
||||
// transformer layer's postprocessing dropout, after attn or ffn module,
|
||||
// before residual add.
|
||||
void bias_dropout_residual(T *output, const T *input, const T *residual,
|
||||
const T *bias, int rows, int cols,
|
||||
cudaStream_t stream) {
|
||||
launch_ls_dropout_res_bias<T>(output, input, _mask, bias, residual,
|
||||
rows * cols, cols, _config.RATIO(), stream);
|
||||
}
|
||||
|
||||
void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output,
|
||||
int rows, int cols, cudaStream_t stream) {
|
||||
launch_ls_dropout_bias_bwd<T>(d_input, d_bias, d_output, _mask, rows, cols,
|
||||
_config.RATIO(), stream);
|
||||
}
|
||||
|
||||
// dropout inside ffn.
|
||||
void bias_act_dropout(T *output, const T *input, const T *bias, int rows,
|
||||
int cols, std::string activation_fn,
|
||||
cudaStream_t stream) {
|
||||
if (activation_fn == "relu") {
|
||||
launch_ls_dropout_act_bias<ActivationType::kRelu, T>(
|
||||
output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
|
||||
stream);
|
||||
} else if (activation_fn == "gelu") {
|
||||
launch_ls_dropout_act_bias<ActivationType::kGelu, T>(
|
||||
output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("not supported activation: " + activation_fn);
|
||||
}
|
||||
}
|
||||
|
||||
void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input,
|
||||
const T *bias, int rows, int cols,
|
||||
std::string activation_fn, cudaStream_t stream) {
|
||||
if (activation_fn == "relu") {
|
||||
launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, T>(
|
||||
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
|
||||
_config.RATIO(), stream);
|
||||
} else if (activation_fn == "gelu") {
|
||||
launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, T>(
|
||||
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
|
||||
_config.RATIO(), stream);
|
||||
} else {
|
||||
throw std::runtime_error("not supported activation: " + activation_fn);
|
||||
}
|
||||
}
|
||||
|
||||
bool HasDropout() const { return _config.RATIO() > 0.0; }
|
||||
|
||||
void SetTrainingMode(bool training) { _config.training = training; }
|
||||
|
||||
private:
|
||||
uint8_t *_mask;
|
||||
Config _config;
|
||||
};
|
|
@ -1,69 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Microsoft DeepSpeed
|
||||
This file is adapted from Microsoft DeepSpeed
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "cublas_wrappers.h"
|
||||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
class FeedForward {
|
||||
public:
|
||||
struct Config {
|
||||
int outputSize;
|
||||
int inputSize;
|
||||
std::array<int, 3> gemm_algos;
|
||||
Config(int outputs, int inputs)
|
||||
: outputSize(outputs),
|
||||
inputSize(inputs),
|
||||
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
|
||||
};
|
||||
|
||||
FeedForward(Config config) : config_(config) {}
|
||||
|
||||
~FeedForward() {}
|
||||
|
||||
void Forward(int bsz, const T *input_ptr, const T *weights, T *out,
|
||||
cublasHandle_t &_cublasHandle) {
|
||||
float alpha = T(1.);
|
||||
float beta = T(0.);
|
||||
|
||||
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize,
|
||||
bsz, config_.inputSize, &alpha, &beta, weights, input_ptr,
|
||||
out, cublasGemmAlgo_t(config_.gemm_algos[0]));
|
||||
}
|
||||
void Backward(int bsz, const T *out_grad, const T *input_ptr,
|
||||
const T *weights, T *weights_grad, T *bias_grad,
|
||||
cublasHandle_t &_cublasHandle, cudaStream_t &stream,
|
||||
T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr,
|
||||
bool compute_bias = true) {
|
||||
float alpha = (T)1.0, beta = (T)0.0;
|
||||
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize,
|
||||
config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad,
|
||||
weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1]));
|
||||
|
||||
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize,
|
||||
bsz, config_.outputSize, &alpha, &beta, weights, out_grad,
|
||||
inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2]));
|
||||
if (compute_bias) {
|
||||
launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz,
|
||||
config_.outputSize, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void reset_size(int outputSize, int inputSize) {
|
||||
config_.outputSize = outputSize;
|
||||
config_.inputSize = inputSize;
|
||||
}
|
||||
|
||||
private:
|
||||
Config config_;
|
||||
};
|
|
@ -1,275 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#define MAX_THREADS 1024
|
||||
#define WARP_SIZE 32
|
||||
|
||||
enum class ActivationType { kRelu, kGelu };
|
||||
|
||||
void launch_curand_init(int total_count, int dim, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp,
|
||||
const T *scale, const T *bias, int batch_size,
|
||||
int hidden_dim, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
|
||||
const T *residual_grad, const T *inp_or_out, const T *gamma,
|
||||
const T *betta, const T *vars, const T *means, int batch,
|
||||
int hidden_dim, cudaStream_t stream[2]);
|
||||
|
||||
template <typename T>
|
||||
void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads,
|
||||
int from_len, int to_len, bool mask_future,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
|
||||
int softmax_len, cudaStream_t stream);
|
||||
|
||||
// [b, s, h] -> [b, nh, s, ad]
|
||||
template <typename T>
|
||||
void launch_transform_0213(T *output, const T *vals, int batch_size,
|
||||
int seq_length, int hidden_dim, int nhead,
|
||||
cudaStream_t stream);
|
||||
|
||||
// [b, s, 3, h] -> [3, b, nh, s, ad]
|
||||
template <typename T>
|
||||
void launch_bias_add_transform_20314(T *output, const T *input, const T *bias,
|
||||
int dim_0, int dim_1, int dim_2, int dim_3,
|
||||
int dim_4, cudaStream_t stream);
|
||||
|
||||
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
|
||||
template <typename T>
|
||||
void launch_transform4d_0213(T *output, const T *vals, int batch_size,
|
||||
int seq_len, int hidden_dim, int nhead,
|
||||
int trans_count, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count,
|
||||
float ratio, cudaStream_t stream, bool backward = false);
|
||||
|
||||
template <typename T>
|
||||
void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask,
|
||||
const T *bias, const T *residual,
|
||||
int total_count, int dim, float ratio,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <ActivationType, typename T>
|
||||
void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask,
|
||||
const T *bias, int total_count, int dim,
|
||||
float ratio, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
|
||||
const uint8_t *mask, int row_size, int dim,
|
||||
float ratio, cudaStream_t stream);
|
||||
|
||||
template <ActivationType act_type, typename T>
|
||||
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
|
||||
const T *bias, const T *out_grad,
|
||||
const uint8_t *mask, int row_size, int dim,
|
||||
float ratio, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols,
|
||||
cudaStream_t stream);
|
||||
|
||||
void launch_param_update(const float *input, __half *output, int size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0,
|
||||
int sz2, int sz1_1, int sz1_2, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size,
|
||||
int seq_len, int hidden_size, cudaStream_t &stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
|
||||
float *outputs_ptr, float *nll_loss_ptr,
|
||||
float *loss_buffer, const int padding_idx,
|
||||
const float epsilon, const int batch_size,
|
||||
const int seq_len, const int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
|
||||
const int *targets_ptr, T *grad_inputs_ptr,
|
||||
const int padding_idx, const float epsilon,
|
||||
const int batch_size, const int seq_len,
|
||||
const int vocab_size, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_lookup_scale_pos_dropout(
|
||||
T *output, const int *input, const T *embeddings, const T *pos_embeddings,
|
||||
uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim,
|
||||
int padding_idx, float dropout_ratio, int step, cudaStream_t &stream);
|
||||
|
||||
template <typename T>
|
||||
void launch_d_lookup_scale_pos_dropout(
|
||||
T *grad_embeddings, const T *grad_output, const int *input,
|
||||
const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim,
|
||||
int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream);
|
||||
|
||||
/* Convert 2-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) {
|
||||
return id1 * dim2 + id2;
|
||||
}
|
||||
|
||||
/* Convert 3-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
|
||||
int dim2, int dim3) {
|
||||
return id1 * dim2 * dim3 + id2 * dim3 + id3;
|
||||
}
|
||||
|
||||
/* Convert 4-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3,
|
||||
int id4, int dim2, int dim3,
|
||||
int dim4) {
|
||||
// return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
|
||||
int res = id4;
|
||||
|
||||
int ld = dim4;
|
||||
res += id3 * ld;
|
||||
|
||||
ld *= dim3;
|
||||
res += id2 * ld;
|
||||
|
||||
ld *= dim2;
|
||||
res += id1 * ld;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Convert 5-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3,
|
||||
int id4, int id5, int dim2,
|
||||
int dim3, int dim4,
|
||||
int dim5) {
|
||||
// return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) +
|
||||
// id4*dim5 + dim5;
|
||||
int res = id5;
|
||||
|
||||
int ld = dim5;
|
||||
res += id4 * ld;
|
||||
|
||||
ld *= dim4;
|
||||
res += id3 * ld;
|
||||
|
||||
ld *= dim3;
|
||||
res += id2 * ld;
|
||||
|
||||
ld *= dim2;
|
||||
res += id1 * ld;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Convert 6-dim tensor index into vector index */
|
||||
__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
|
||||
int id4, int id5, int id6,
|
||||
int dim2, int dim3, int dim4,
|
||||
int dim5, int dim6) {
|
||||
// return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) +
|
||||
// id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6;
|
||||
int res = id6;
|
||||
|
||||
int ld = dim6;
|
||||
res += id5 * ld;
|
||||
|
||||
ld *= dim5;
|
||||
res += id4 * ld;
|
||||
|
||||
ld *= dim4;
|
||||
res += id3 * ld;
|
||||
|
||||
ld *= dim3;
|
||||
res += id2 * ld;
|
||||
|
||||
ld *= dim2;
|
||||
res += id1 * ld;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Convert vector index to 6-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_6dim(
|
||||
int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0,
|
||||
int *id1, int *id2, int *id3, int *id4, int *id5) {
|
||||
*id5 = src % dim5;
|
||||
src /= dim5;
|
||||
|
||||
*id4 = src % dim4;
|
||||
src /= dim4;
|
||||
|
||||
*id3 = src % dim3;
|
||||
src /= dim3;
|
||||
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
||||
|
||||
/* Convert vector index to 5-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1,
|
||||
int dim2, int dim3,
|
||||
int dim4, int *id0,
|
||||
int *id1, int *id2,
|
||||
int *id3, int *id4) {
|
||||
*id4 = src % dim4;
|
||||
src /= dim4;
|
||||
|
||||
*id3 = src % dim3;
|
||||
src /= dim3;
|
||||
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
||||
|
||||
/* Convert vector index to 4-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
|
||||
int dim2, int dim3,
|
||||
int *id0, int *id1,
|
||||
int *id2, int *id3) {
|
||||
*id3 = src % dim3;
|
||||
src /= dim3;
|
||||
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
||||
|
||||
/* Convert vector index to 3-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1,
|
||||
int dim2, int *id0,
|
||||
int *id1, int *id2) {
|
||||
*id2 = src % dim2;
|
||||
src /= dim2;
|
||||
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
||||
|
||||
/* Convert vector index to 2-dim tensor index */
|
||||
__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1,
|
||||
int *id0, int *id1) {
|
||||
*id1 = src % dim1;
|
||||
*id0 = src / dim1;
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
// copied from https://github.com/dmlc/dgl/pull/2758
|
||||
#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
|
||||
#define DGL_ARRAY_CUDA_DGL_CUB_CUH_
|
||||
|
||||
#define CUB_NS_PREFIX namespace ls {
|
||||
#define CUB_NS_POSTFIX }
|
||||
#include "cub/cub.cuh"
|
||||
#include "cub/util_allocator.cuh"
|
||||
#undef CUB_NS_POSTFIX
|
||||
#undef CUB_NS_PREFIX
|
||||
|
||||
#endif
|
|
@ -1,65 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
template <typename T>
|
||||
class Normalize_Layer {
|
||||
public:
|
||||
struct Config {
|
||||
uint32_t hidden_dim;
|
||||
bool use_mean;
|
||||
Config(uint32_t hidden_dim, bool use_mean = false)
|
||||
: hidden_dim(hidden_dim), use_mean(use_mean) {}
|
||||
};
|
||||
|
||||
Normalize_Layer(Config config, size_t max_rows)
|
||||
: config_(config), vars_(nullptr), means_(nullptr) {
|
||||
vars_ = cuda_malloc<T>(max_rows);
|
||||
if (config_.use_mean) {
|
||||
means_ = cuda_malloc<T>(max_rows);
|
||||
}
|
||||
}
|
||||
|
||||
~Normalize_Layer() {
|
||||
cuda_free(vars_);
|
||||
cuda_free(means_);
|
||||
}
|
||||
|
||||
void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta,
|
||||
int batch_size, cudaStream_t stream) {
|
||||
launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size,
|
||||
config_.hidden_dim, stream);
|
||||
}
|
||||
|
||||
/*
|
||||
residual_grad, inp_or_out, betta should be treated carefully.
|
||||
inp_or_out = input if use_mean else output
|
||||
residual_grad, betta can be nullptr.
|
||||
residual_grad will be added to dinp if it is not nullptr
|
||||
which is useful in transformer layer when pre-ln
|
||||
betta are only used to compute xhat,
|
||||
(use_mean == false) ^ (betta == nullptr) should be true
|
||||
*/
|
||||
void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
|
||||
const T *residual_grad, const T *inp_or_out, const T *gamma,
|
||||
const T *betta, int batch_size, cudaStream_t stream[2]) {
|
||||
launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad,
|
||||
inp_or_out, gamma, betta, vars_, means_, batch_size,
|
||||
config_.hidden_dim, stream);
|
||||
}
|
||||
|
||||
inline bool use_mean() const { return config_.use_mean; }
|
||||
|
||||
private:
|
||||
Config config_;
|
||||
T *vars_;
|
||||
T *means_;
|
||||
};
|
|
@ -1,42 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
template <typename T>
|
||||
class Softmax {
|
||||
public:
|
||||
struct Config {
|
||||
size_t nhead;
|
||||
Config(size_t nhead) : nhead(nhead) {}
|
||||
};
|
||||
|
||||
Softmax(Config config) : config_(config) {}
|
||||
|
||||
~Softmax() {}
|
||||
|
||||
void Forward(T *vals, const T *attn_mask, int batch_size, int from_len,
|
||||
int to_len, cudaStream_t &stream, bool mask_future = true) {
|
||||
launch_attn_softmax<T>(vals, attn_mask, batch_size, config_.nhead, from_len,
|
||||
to_len, mask_future, stream);
|
||||
}
|
||||
|
||||
void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len,
|
||||
int to_len, cudaStream_t stream) {
|
||||
launch_attn_softmax_bw<T>(out_grad, soft_out,
|
||||
batch_size * config_.nhead * from_len, to_len,
|
||||
stream);
|
||||
}
|
||||
|
||||
void reset_size(size_t nhead) { config_.nhead = nhead; }
|
||||
|
||||
private:
|
||||
Config config_;
|
||||
};
|
|
@ -1,100 +0,0 @@
|
|||
/* Copyright 2021 The LightSeq Team
|
||||
Copyright Microsoft DeepSpeed
|
||||
This file is adapted from Microsoft DeepSpeed
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "cublas_wrappers.h"
|
||||
|
||||
template <typename T>
|
||||
class StridedBatchGemm {
|
||||
public:
|
||||
struct Config {
|
||||
int m;
|
||||
int n;
|
||||
int k;
|
||||
float alpha;
|
||||
float beta;
|
||||
cublasOperation_t op_A;
|
||||
cublasOperation_t op_B;
|
||||
std::array<int, 3> gemm_algos;
|
||||
|
||||
Config(float param_alpha, float param_beta, cublasOperation_t opA,
|
||||
cublasOperation_t opB)
|
||||
: alpha(param_alpha),
|
||||
beta(param_beta),
|
||||
op_A(opA),
|
||||
op_B(opB),
|
||||
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
|
||||
void SetConfig(int mm, int nn, int kk) {
|
||||
m = mm;
|
||||
n = nn;
|
||||
k = kk;
|
||||
}
|
||||
};
|
||||
|
||||
StridedBatchGemm(const Config &config) : _config(config) {}
|
||||
|
||||
virtual ~StridedBatchGemm() {}
|
||||
|
||||
void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
|
||||
cublasHandle_t handle) {
|
||||
int stride_a = _config.m * _config.k;
|
||||
int stride_b = _config.n * _config.k;
|
||||
int stride_c = _config.m * _config.n;
|
||||
|
||||
cublas_strided_batched_gemm(
|
||||
handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
|
||||
_buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
|
||||
stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
|
||||
}
|
||||
|
||||
void Backward(int bsz, const T *d_output, const T *_buffer_a,
|
||||
const T *_buffer_b, cublasHandle_t handle,
|
||||
T *inpGradA = nullptr, T *inpGradB = nullptr) {
|
||||
int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
|
||||
int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
|
||||
|
||||
int stride_a = mb * _config.n;
|
||||
int stride_b = _config.n * kb;
|
||||
int stride_c = _config.m * _config.k;
|
||||
|
||||
// B need to transpose.
|
||||
cublasOperation_t op_b =
|
||||
(_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
|
||||
|
||||
// Calculate d_A.
|
||||
cublas_strided_batched_gemm(
|
||||
handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
|
||||
(_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
|
||||
(_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
|
||||
CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
|
||||
cublasGemmAlgo_t(_config.gemm_algos[1]));
|
||||
|
||||
// A need to transpose.
|
||||
cublasOperation_t op_a =
|
||||
(_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
|
||||
|
||||
stride_a = _config.m * _config.k;
|
||||
stride_b = _config.m * _config.n;
|
||||
stride_c = _config.n * _config.k;
|
||||
|
||||
// Calculate d_B.
|
||||
cublas_strided_batched_gemm(
|
||||
handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
|
||||
_buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
|
||||
stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
|
||||
}
|
||||
|
||||
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
|
||||
|
||||
private:
|
||||
Config _config;
|
||||
};
|
File diff suppressed because it is too large
Load Diff
|
@ -1,365 +0,0 @@
|
|||
#include <cooperative_groups.h>
|
||||
#include <math.h>
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include "block_reduce.h"
|
||||
#include "kernels.h"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
const float EPSILON = 1e-8f;
|
||||
|
||||
/**
|
||||
@brief: softmax_kernel
|
||||
Softmax forward kernel for
|
||||
enc-self-attn, dec-self-attn, encdec-attn
|
||||
|
||||
@thread
|
||||
gridDim.x = dynamic
|
||||
gridDim.y = batch_size
|
||||
gridDim.z = nhead
|
||||
blockDim.x = from_len
|
||||
|
||||
@param
|
||||
inp: [batch_size, nhead, from_len, to_len], softmax input.
|
||||
attn_mask: [batch_size, to_len], padding tokens are -inf,
|
||||
non padding tokens are 0.
|
||||
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
|
||||
attn_mask=nullptr and mask_future=ture for dec-self-attn training
|
||||
attn_mask=nullptr and mask_future=false for dec-self-attn infer
|
||||
*/
|
||||
template <typename T, int block_dim, int ele_per_thread>
|
||||
__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
|
||||
int to_len, bool mask_future) {
|
||||
int batch_id = blockIdx.y;
|
||||
int head_id = blockIdx.z;
|
||||
const int nhead = gridDim.z;
|
||||
const int token_per_reduce = 1;
|
||||
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
|
||||
cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
typedef cub::BlockStore<T, block_dim, ele_per_thread,
|
||||
cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
T mval[ele_per_thread];
|
||||
if (attn_mask) {
|
||||
attn_mask += batch_id * to_len;
|
||||
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
|
||||
}
|
||||
|
||||
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
|
||||
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
|
||||
token_id += gridDim.x * token_per_reduce) {
|
||||
T inp_val[token_per_reduce][ele_per_thread];
|
||||
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
|
||||
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
|
||||
REDUCE_FLOAT_INF_NEG);
|
||||
}
|
||||
|
||||
/* step 1. compute max */
|
||||
// thread local max
|
||||
float val[token_per_reduce][ele_per_thread];
|
||||
float l_max[token_per_reduce];
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
l_max[i] = REDUCE_FLOAT_INF_NEG;
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
if (attn_mask) {
|
||||
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
|
||||
} else {
|
||||
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
|
||||
val[i][j] = REDUCE_FLOAT_INF_NEG;
|
||||
} else {
|
||||
val[i][j] = (float)inp_val[i][j];
|
||||
}
|
||||
}
|
||||
l_max[i] = fmaxf(l_max[i], val[i][j]);
|
||||
}
|
||||
}
|
||||
// block reduce max
|
||||
blockReduce<ReduceType::kMax, token_per_reduce>(l_max);
|
||||
// write shared
|
||||
__shared__ float s_max[token_per_reduce];
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
s_max[i] = l_max[i];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
/* step 2. compute sum */
|
||||
// thread local sum
|
||||
float l_sum[token_per_reduce];
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
l_sum[i] = 0.f;
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
val[i][j] = __expf(val[i][j] - s_max[i]);
|
||||
l_sum[i] += val[i][j];
|
||||
}
|
||||
}
|
||||
// block reduce sum
|
||||
blockReduce<ReduceType::kSum, token_per_reduce>(l_sum);
|
||||
// write shared
|
||||
__shared__ float s_sum[token_per_reduce];
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
/* step 3. compute final result */
|
||||
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
inp_val[i][j] = (T)(val[i][j] * s_sum[i]);
|
||||
}
|
||||
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
|
||||
to_len);
|
||||
}
|
||||
} // blockIdx.x
|
||||
}
|
||||
|
||||
template <typename T, int block_dim, int ele_per_thread>
|
||||
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
|
||||
int to_len, bool mask_future) {
|
||||
int batch_id = blockIdx.y;
|
||||
int head_id = blockIdx.z;
|
||||
const int nhead = gridDim.z;
|
||||
const int token_per_reduce = 1;
|
||||
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
|
||||
cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
typedef cub::BlockStore<T, block_dim, ele_per_thread,
|
||||
cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
T mval[ele_per_thread];
|
||||
if (attn_mask) {
|
||||
attn_mask += batch_id * to_len;
|
||||
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
|
||||
}
|
||||
|
||||
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
|
||||
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
|
||||
token_id += gridDim.x * token_per_reduce) {
|
||||
T inp_val[token_per_reduce][ele_per_thread];
|
||||
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
|
||||
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
|
||||
REDUCE_FLOAT_INF_NEG);
|
||||
}
|
||||
|
||||
/* step 1. compute max */
|
||||
// thread local max
|
||||
float val[token_per_reduce][ele_per_thread];
|
||||
float l_max[token_per_reduce];
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
l_max[i] = REDUCE_FLOAT_INF_NEG;
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
if (attn_mask) {
|
||||
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
|
||||
} else {
|
||||
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
|
||||
val[i][j] = REDUCE_FLOAT_INF_NEG;
|
||||
} else {
|
||||
val[i][j] = (float)inp_val[i][j];
|
||||
}
|
||||
}
|
||||
l_max[i] = fmaxf(l_max[i], val[i][j]);
|
||||
}
|
||||
}
|
||||
// warp reduce max
|
||||
warpReduce<ReduceType::kMax, token_per_reduce>(l_max);
|
||||
|
||||
/* step 2. compute sum */
|
||||
// thread local sum
|
||||
float l_sum[token_per_reduce];
|
||||
for (int i = 0; i < token_per_reduce; i++) {
|
||||
l_sum[i] = 0.f;
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
val[i][j] = __expf(val[i][j] - l_max[i]);
|
||||
l_sum[i] += val[i][j];
|
||||
}
|
||||
}
|
||||
// warp reduce sum
|
||||
warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);
|
||||
|
||||
/* step 3. compute final result */
|
||||
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
|
||||
l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
|
||||
for (int j = 0; j < ele_per_thread; j++) {
|
||||
inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
|
||||
}
|
||||
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
|
||||
to_len);
|
||||
}
|
||||
} // blockIdx.x
|
||||
}
|
||||
|
||||
/*
|
||||
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
|
||||
attn_mask=nullptr and mask_future=ture for dec-self-attn training
|
||||
attn_mask=nullptr and mask_future=false for dec-self-attn infer
|
||||
*/
|
||||
template <>
|
||||
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
|
||||
int batch_size, int nhead, int from_len,
|
||||
int to_len, bool mask_future,
|
||||
cudaStream_t stream) {
|
||||
dim3 grid_dim(1, batch_size, nhead);
|
||||
if (to_len <= 32) {
|
||||
ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 64) {
|
||||
ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 128) {
|
||||
grid_dim.x = 16;
|
||||
ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 256) {
|
||||
grid_dim.x = 32;
|
||||
ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 512) {
|
||||
grid_dim.x = 64;
|
||||
ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Sequence length greater than 512 is currently not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
|
||||
int batch_size, int nhead, int from_len,
|
||||
int to_len, bool mask_future,
|
||||
cudaStream_t stream) {
|
||||
dim3 grid_dim(1, batch_size, nhead);
|
||||
if (to_len <= 32) {
|
||||
ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 64) {
|
||||
ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 128) {
|
||||
grid_dim.x = 8;
|
||||
ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 256) {
|
||||
grid_dim.x = 16;
|
||||
ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else if (to_len <= 512) {
|
||||
grid_dim.x = 32;
|
||||
ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
|
||||
inp, attn_mask, from_len, to_len, mask_future);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Sequence length greater than 512 is currently not supported");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@brief: ker_attn_softmax_bw
|
||||
Softmax backward in self attention.
|
||||
|
||||
@thread
|
||||
gridDim.x = batch_size * nhead * seq_len / warps_per_block
|
||||
blockDim.x = WARP_SIZE
|
||||
blockDim.y = warps_per_block
|
||||
|
||||
@param
|
||||
grad: [batch_size, nhead, seq_len, seq_len], output grad.
|
||||
output: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
|
||||
*/
|
||||
template <typename T, int ITERATIONS>
|
||||
__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
|
||||
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
int offset = batch_idx * softmax_length + threadIdx.x;
|
||||
|
||||
grad += offset;
|
||||
inp += offset;
|
||||
|
||||
T grad_reg[ITERATIONS];
|
||||
T inp_reg[ITERATIONS];
|
||||
float sum = 0.0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITERATIONS; ++i) {
|
||||
int curr_idx = threadIdx.x + i * WARP_SIZE;
|
||||
if (curr_idx < softmax_length) {
|
||||
grad_reg[i] = grad[i * WARP_SIZE];
|
||||
inp_reg[i] = inp[i * WARP_SIZE];
|
||||
sum += (float)grad_reg[i] * (float)inp_reg[i];
|
||||
}
|
||||
}
|
||||
|
||||
cg::thread_block b = cg::this_thread_block();
|
||||
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
|
||||
|
||||
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITERATIONS; ++i) {
|
||||
int curr_idx = threadIdx.x + i * WARP_SIZE;
|
||||
if (curr_idx < softmax_length)
|
||||
grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
|
||||
int softmax_len, cudaStream_t stream) {
|
||||
const int warps_per_block = 4;
|
||||
// rows = batch_size * nhead * from_len
|
||||
dim3 grid_dim(rows / warps_per_block);
|
||||
dim3 block_dim(WARP_SIZE, warps_per_block);
|
||||
|
||||
if (softmax_len <= 32)
|
||||
ker_attn_softmax_bw<T, 1>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 64)
|
||||
ker_attn_softmax_bw<T, 2>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 128)
|
||||
ker_attn_softmax_bw<T, 4>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 256)
|
||||
ker_attn_softmax_bw<T, 8>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 384)
|
||||
ker_attn_softmax_bw<T, 12>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 512)
|
||||
ker_attn_softmax_bw<T, 16>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 768)
|
||||
ker_attn_softmax_bw<T, 24>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 1024)
|
||||
ker_attn_softmax_bw<T, 32>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else if (softmax_len <= 2048)
|
||||
ker_attn_softmax_bw<T, 64>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
|
||||
else
|
||||
throw std::runtime_error(
|
||||
std::string(
|
||||
"Special sequence length found in softmax backward, seq_len: ") +
|
||||
std::to_string(softmax_len));
|
||||
}
|
||||
|
||||
template void launch_attn_softmax_bw<__half>(__half *out_grad,
|
||||
const __half *soft_inp, int rows,
|
||||
int softmax_len,
|
||||
cudaStream_t stream);
|
||||
template void launch_attn_softmax_bw<float>(float *out_grad,
|
||||
const float *soft_inp, int rows,
|
||||
int softmax_len,
|
||||
cudaStream_t stream);
|
|
@ -1,314 +0,0 @@
|
|||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
using namespace cub;
|
||||
|
||||
/**
|
||||
@brief: transform_0213
|
||||
Split the attention heads and reshape input
|
||||
during backward progress of encoder self-attention
|
||||
|
||||
@thread
|
||||
gridDim.x = batch_size
|
||||
gridDim.y = seq_len
|
||||
blockDim.x = min(hidden_dim, MAX_THREADS)
|
||||
|
||||
@param
|
||||
input: [batch_size, seq_len, hidden_dim]
|
||||
output: [batch_size, nhead, seq_len, head_dim]
|
||||
batch_size: the size of the current batch
|
||||
seq_len: the sequence length of the current batch
|
||||
hidden_dim: dim of the hidden tensor
|
||||
nhead: number of attention heads
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
|
||||
int head_dim);
|
||||
|
||||
template <>
|
||||
__global__ void transform_0213<float>(float *output, const float *input,
|
||||
int hidden_dim, int head_dim) {
|
||||
int batch_id = blockIdx.x;
|
||||
int token_id = blockIdx.y;
|
||||
int seq_len = gridDim.y;
|
||||
int nhead = hidden_dim / head_dim;
|
||||
|
||||
// [b, s, h]
|
||||
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
|
||||
// [b, nh, s, ad]
|
||||
int trg_offset =
|
||||
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
|
||||
|
||||
const float4 *input4 = reinterpret_cast<const float4 *>(input);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
float4 vinput4;
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
|
||||
vinput4 = input4[src_offset + i];
|
||||
|
||||
int head_id = i / head_dim;
|
||||
int dim_id = i % head_dim;
|
||||
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
|
||||
res4[trg_offset + cur_trg_offset] = vinput4;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void transform_0213<__half>(__half *output, const __half *input,
|
||||
int hidden_dim, int head_dim) {
|
||||
int batch_id = blockIdx.x;
|
||||
int token_id = blockIdx.y;
|
||||
int seq_len = gridDim.y;
|
||||
int nhead = hidden_dim / head_dim;
|
||||
|
||||
// [b, s, h]
|
||||
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
|
||||
// [b, nh, s, ad]
|
||||
int trg_offset =
|
||||
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
|
||||
|
||||
const float4 *input4 = reinterpret_cast<const float4 *>(input);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
float4 vinput4;
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
|
||||
vinput4 = input4[src_offset + i];
|
||||
|
||||
int head_id = i / head_dim;
|
||||
int dim_id = i % head_dim;
|
||||
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
|
||||
res4[trg_offset + cur_trg_offset] = vinput4;
|
||||
}
|
||||
}
|
||||
|
||||
// [b, s, h] -> [b, nh, s, ad]
|
||||
template <>
|
||||
void launch_transform_0213<float>(float *output, const float *input,
|
||||
int batch_size, int seq_len, int hidden_dim,
|
||||
int nhead, cudaStream_t stream) {
|
||||
hidden_dim >>= 2;
|
||||
int head_dim = hidden_dim / nhead;
|
||||
|
||||
dim3 grid_dim(batch_size, seq_len);
|
||||
dim3 block_dim(min(hidden_dim, MAX_THREADS));
|
||||
|
||||
transform_0213<float>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_transform_0213<__half>(__half *output, const __half *input,
|
||||
int batch_size, int seq_len, int hidden_dim,
|
||||
int nhead, cudaStream_t stream) {
|
||||
hidden_dim >>= 3;
|
||||
int head_dim = hidden_dim / nhead;
|
||||
|
||||
dim3 grid_dim(batch_size, seq_len);
|
||||
dim3 block_dim(min(hidden_dim, MAX_THREADS));
|
||||
|
||||
transform_0213<__half>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
|
||||
}
|
||||
|
||||
/**
|
||||
@brief: bias_add_transform_20314
|
||||
Add bias to input, transform from
|
||||
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
|
||||
|
||||
@thread
|
||||
gridDim.x = dim_0
|
||||
gridDim.y = dim_1
|
||||
gridDim.z = dim_2
|
||||
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
|
||||
|
||||
@param
|
||||
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
|
||||
bias: [dim_2, dim_3, dim_4]
|
||||
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
|
||||
*/
|
||||
template <typename T>
|
||||
__global__ void bias_add_transform_20314(T *output, const T *input,
|
||||
const T *bias, int dim_3, int dim_4);
|
||||
|
||||
template <>
|
||||
__global__ void bias_add_transform_20314<float>(float *output,
|
||||
const float *input,
|
||||
const float *bias, int dim_3,
|
||||
int dim_4) {
|
||||
int id0 = blockIdx.x;
|
||||
int id1 = blockIdx.y;
|
||||
int id2 = blockIdx.z;
|
||||
int dim_0 = gridDim.x;
|
||||
int dim_1 = gridDim.y;
|
||||
int dim_2 = gridDim.z;
|
||||
int dim_34 = dim_3 * dim_4;
|
||||
|
||||
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
|
||||
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
|
||||
int bias_offset = flat_2dim(id2, 0, dim_34);
|
||||
|
||||
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
|
||||
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
float4 vqkv4;
|
||||
float4 vbias4;
|
||||
float4 vres4;
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
|
||||
vqkv4 = qkv4[src_offset + i];
|
||||
vbias4 = bias4[bias_offset + i];
|
||||
vres4.x = vqkv4.x + vbias4.x;
|
||||
vres4.y = vqkv4.y + vbias4.y;
|
||||
vres4.z = vqkv4.z + vbias4.z;
|
||||
vres4.w = vqkv4.w + vbias4.w;
|
||||
|
||||
int id3 = i / dim_4;
|
||||
int id4 = i % dim_4;
|
||||
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
|
||||
res4[trg_offset + cur_trg_offset] = vres4;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void bias_add_transform_20314<__half>(__half *output,
|
||||
const __half *input,
|
||||
const __half *bias, int dim_3,
|
||||
int dim_4) {
|
||||
int id0 = blockIdx.x;
|
||||
int id1 = blockIdx.y;
|
||||
int id2 = blockIdx.z;
|
||||
int dim_0 = gridDim.x;
|
||||
int dim_1 = gridDim.y;
|
||||
int dim_2 = gridDim.z;
|
||||
int dim_34 = dim_3 * dim_4;
|
||||
|
||||
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
|
||||
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
|
||||
int bias_offset = flat_2dim(id2, 0, dim_34);
|
||||
|
||||
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
|
||||
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
float4 vqkv4;
|
||||
float4 vbias4;
|
||||
float4 vres4;
|
||||
__half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4);
|
||||
__half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4);
|
||||
__half2 *h2_res = reinterpret_cast<__half2 *>(&vres4);
|
||||
|
||||
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
|
||||
vqkv4 = qkv4[src_offset + i];
|
||||
vbias4 = bias4[bias_offset + i];
|
||||
h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]);
|
||||
h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]);
|
||||
h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]);
|
||||
h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]);
|
||||
|
||||
int id3 = i / dim_4;
|
||||
int id4 = i % dim_4;
|
||||
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
|
||||
res4[trg_offset + cur_trg_offset] = vres4;
|
||||
}
|
||||
}
|
||||
|
||||
// [b, s, 3, h] -> [3, b, nh, s, ad]
|
||||
template <>
|
||||
void launch_bias_add_transform_20314<float>(float *output, const float *input,
|
||||
const float *bias, int dim_0,
|
||||
int dim_1, int dim_2, int dim_3,
|
||||
int dim_4, cudaStream_t stream) {
|
||||
dim_4 >>= 2;
|
||||
|
||||
dim3 grid_dim(dim_0, dim_1, dim_2);
|
||||
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
|
||||
|
||||
bias_add_transform_20314<float>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_bias_add_transform_20314<__half>(__half *output,
|
||||
const __half *input,
|
||||
const __half *bias, int dim_0,
|
||||
int dim_1, int dim_2, int dim_3,
|
||||
int dim_4, cudaStream_t stream) {
|
||||
dim_4 >>= 3;
|
||||
|
||||
dim3 grid_dim(dim_0, dim_1, dim_2);
|
||||
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
|
||||
|
||||
bias_add_transform_20314<__half>
|
||||
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
|
||||
}
|
||||
|
||||
/**
|
||||
@brief: transform4d_0213
|
||||
Reshape the input matrix to merge the heads
|
||||
|
||||
@thread
|
||||
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
|
||||
blockDim.x = max_block_thread
|
||||
|
||||
@param
|
||||
input: [trans_count, batch_size, nhead, seq_len, head_dim]
|
||||
output: [batch_size, seq_len, trans_count, nhead, head_dim]
|
||||
batch_size: the size of the current batch
|
||||
seq_len: the sequence length of the current batch
|
||||
hidden_dim: dim of the hidden tensor
|
||||
nhead: number of attention heads
|
||||
trans_count: 1 or 3, the count of matrice need to be transformed
|
||||
*/
|
||||
template <typename T>
|
||||
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
|
||||
int seq_len, int trans_count, int nhead,
|
||||
int head_dim, int num_all) {
|
||||
int offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (offset >= num_all) {
|
||||
return;
|
||||
}
|
||||
int trans_id, batch_id, head_id, token_id, dim_id;
|
||||
decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
|
||||
&batch_id, &head_id, &token_id, &dim_id);
|
||||
// [b, s, tc, nh, ad]
|
||||
int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
|
||||
seq_len, trans_count, nhead, head_dim);
|
||||
|
||||
const float4 *input4 = reinterpret_cast<const float4 *>(input);
|
||||
float4 *res4 = reinterpret_cast<float4 *>(output);
|
||||
res4[trg_offset] = input4[offset];
|
||||
}
|
||||
|
||||
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
|
||||
template <>
|
||||
void launch_transform4d_0213<float>(float *output, const float *input,
|
||||
int batch_size, int seq_len, int hidden_dim,
|
||||
int nhead, int trans_count,
|
||||
cudaStream_t stream) {
|
||||
hidden_dim >>= 2;
|
||||
int head_dim = hidden_dim / nhead;
|
||||
int num_all = batch_size * seq_len * trans_count * hidden_dim;
|
||||
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
|
||||
|
||||
transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
|
||||
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
|
||||
num_all);
|
||||
}
|
||||
|
||||
template <>
|
||||
void launch_transform4d_0213<__half>(__half *output, const __half *input,
|
||||
int batch_size, int seq_len,
|
||||
int hidden_dim, int nhead, int trans_count,
|
||||
cudaStream_t stream) {
|
||||
hidden_dim >>= 3;
|
||||
int head_dim = hidden_dim / nhead;
|
||||
int num_all = batch_size * seq_len * trans_count * hidden_dim;
|
||||
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
|
||||
|
||||
transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
|
||||
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
|
||||
num_all);
|
||||
}
|
|
@ -1,406 +0,0 @@
|
|||
#include "multihead_attention_1d.h"
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#if TORCH_VERSION_MAJOR > 1 || \
|
||||
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
|
||||
#include <torch/csrc/distributed/c10d/Types.hpp>
|
||||
#else
|
||||
#include <c10d/Types.hpp>
|
||||
#endif
|
||||
#include <iostream>
|
||||
|
||||
#include "context.h"
|
||||
#include "kernels.h"
|
||||
|
||||
template <typename T>
|
||||
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
|
||||
int max_seq_len, int hidden_size,
|
||||
int num_heads,
|
||||
float attn_prob_dropout_ratio,
|
||||
float hidden_output_dropout_ratio,
|
||||
bool pre_or_postLayerNorm)
|
||||
: _layer_id(layer_id),
|
||||
_max_batch_tokens(max_batch_tokens),
|
||||
_max_seq_len(max_seq_len),
|
||||
_hidden_size(hidden_size),
|
||||
_heads(num_heads),
|
||||
_training(true),
|
||||
_pre_or_postLayerNorm(pre_or_postLayerNorm),
|
||||
_qkv_linear(
|
||||
typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
|
||||
_attn_out_linear(
|
||||
typename FeedForward<T>::Config(hidden_size, hidden_size)),
|
||||
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
|
||||
_max_batch_tokens),
|
||||
_softmax(typename Softmax<T>::Config(num_heads)),
|
||||
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
|
||||
_max_batch_tokens * _heads * _max_seq_len),
|
||||
_attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
|
||||
_max_batch_tokens * _hidden_size),
|
||||
_attn_scores(typename StridedBatchGemm<T>::Config(
|
||||
(T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
|
||||
CUBLAS_OP_N)),
|
||||
_attn_context(typename StridedBatchGemm<T>::Config(
|
||||
T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
|
||||
assert(_hidden_size % _heads == 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MultiHeadAttention<T>::~MultiHeadAttention() {
|
||||
free_mem_buffer();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
|
||||
const T *input_mask_ptr,
|
||||
T *output_ptr, T *buffer) {
|
||||
T *q_tf_ptr = _qkv_ptr;
|
||||
T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
|
||||
T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
|
||||
_batch_tokens, _stream);
|
||||
}
|
||||
const T *gemmQKV_inp_ptr =
|
||||
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
|
||||
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
|
||||
_cublasHandle);
|
||||
|
||||
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
|
||||
_batch_size, _seq_len, 3, _heads / pg_size,
|
||||
_hidden_size / _heads, _stream);
|
||||
|
||||
// attention scores, q*k
|
||||
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
|
||||
_cublasHandle);
|
||||
|
||||
// Softmax + Mask
|
||||
_softmax.reset_size(_heads / pg_size);
|
||||
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
|
||||
_seq_len, _stream, true);
|
||||
|
||||
// attn prob dropout.
|
||||
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
|
||||
_batch_heads * _seq_len * _seq_len, _stream);
|
||||
|
||||
// attention context, score * v
|
||||
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
|
||||
_cublasHandle);
|
||||
|
||||
// [b, nh, s, ad] -> [b, s, nh, ad]
|
||||
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
|
||||
_hidden_size / pg_size, _heads / pg_size, 1,
|
||||
_stream);
|
||||
|
||||
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
|
||||
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
|
||||
output_ptr, _cublasHandle);
|
||||
|
||||
// allreduce
|
||||
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
|
||||
} else {
|
||||
auto data_type = torch::kFloat;
|
||||
if (typeid(T) != typeid(float)) {
|
||||
data_type = torch::kHalf;
|
||||
}
|
||||
auto output_tensor = torch::from_blob(
|
||||
output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
|
||||
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
|
||||
work->wait();
|
||||
}
|
||||
|
||||
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
|
||||
_attn_ob_ptr, _batch_tokens, _hidden_size,
|
||||
_stream);
|
||||
if (!_pre_or_postLayerNorm) {
|
||||
// in-place ln since ln-input will not be used in post-ln mode
|
||||
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
|
||||
_batch_tokens, _stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
|
||||
T *out_ptr) {
|
||||
_stream = Context::Instance().get_stream();
|
||||
_cublasHandle = Context::Instance().get_cublashandle();
|
||||
T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim
|
||||
|
||||
attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
|
||||
const T *input_mask_ptr,
|
||||
const T *output_ptr,
|
||||
const T *grad_output_ptr,
|
||||
T *grad_input_ptr, T *buffer) {
|
||||
cudaStream_t streams[2] = {_stream, _stream};
|
||||
|
||||
const T *q_tf_ptr = _qkv_ptr;
|
||||
const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
|
||||
const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
|
||||
// batch_dim = batch_size * seq_len * hidden_size
|
||||
// buffer size: batch_dim * 3 + max(batch_dim * 3,
|
||||
// batch_size * head_num * seq_len * seq_len)
|
||||
T *grad_residual_ptr = buffer;
|
||||
buffer += _batch_dim;
|
||||
|
||||
T *grad_input_buf_ptr = buffer; // batch_dim
|
||||
T *grad_qkv_5d_ptr = buffer; // batch_dim * 3
|
||||
buffer += 3 * _batch_dim / pg_size;
|
||||
|
||||
T *grad_qkv_4d_ptr = buffer; // batch_dim * 3
|
||||
T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len
|
||||
// buffer += max(3 * _batch_dim,
|
||||
// batch_size * head_num * seq_len * seq_len);
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
|
||||
grad_output_ptr, _batch_tokens,
|
||||
_hidden_size, _stream);
|
||||
} else {
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
|
||||
grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
|
||||
_attn_nb_ptr, _batch_tokens, streams);
|
||||
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
|
||||
grad_residual_ptr, _batch_tokens,
|
||||
_hidden_size, _stream);
|
||||
}
|
||||
|
||||
// bw of output project
|
||||
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
|
||||
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
|
||||
_attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
|
||||
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
|
||||
false);
|
||||
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
|
||||
_seq_len, _hidden_size / pg_size, _heads / pg_size,
|
||||
_stream);
|
||||
|
||||
// bw of score * v
|
||||
_attn_context.Backward(
|
||||
_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
|
||||
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
|
||||
|
||||
_attn_prob_dropout.d_dropout(grad_softmax_ptr,
|
||||
_batch_heads * _seq_len * _seq_len, _stream);
|
||||
|
||||
_softmax.reset_size(_heads / pg_size);
|
||||
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
|
||||
_seq_len, _stream);
|
||||
|
||||
// bw of q * k
|
||||
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
|
||||
_cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
|
||||
grad_qkv_5d_ptr);
|
||||
|
||||
// [3, b, nh, s, ad] -> [b, s, 3, h]
|
||||
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
|
||||
_seq_len, _hidden_size / pg_size, _heads / pg_size,
|
||||
3, _stream);
|
||||
|
||||
const T *gemmQKV_inp_ptr =
|
||||
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
|
||||
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
|
||||
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
|
||||
_attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
|
||||
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
|
||||
true);
|
||||
|
||||
// allreduce
|
||||
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
|
||||
} else {
|
||||
auto data_type = torch::kFloat;
|
||||
if (typeid(T) != typeid(float)) {
|
||||
data_type = torch::kHalf;
|
||||
}
|
||||
auto grad_input_tensor =
|
||||
torch::from_blob(grad_input_buf_ptr,
|
||||
{int(_batch_size), int(_seq_len), int(_hidden_size)},
|
||||
torch::TensorOptions(torch::kCUDA).dtype(data_type));
|
||||
std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
|
||||
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
|
||||
work->wait();
|
||||
}
|
||||
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
|
||||
grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
|
||||
_attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
|
||||
} else {
|
||||
// FIXME later
|
||||
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
|
||||
_batch_size, _seq_len, _hidden_size, _stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
|
||||
const T *input_ptr, const T *output_ptr,
|
||||
const T *input_mask_ptr,
|
||||
T *grad_input_ptr) {
|
||||
_stream = Context::Instance().get_stream();
|
||||
_cublasHandle = Context::Instance().get_cublashandle();
|
||||
T *buffer = _shared_mem_ptr;
|
||||
|
||||
/*
|
||||
buffer size needed by attn bw:
|
||||
4 * _batch_dim + max(3 * _batch_dim,
|
||||
_batch_size * _head_num * _seq_len * _seq_len);
|
||||
*/
|
||||
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
|
||||
grad_input_ptr, buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MultiHeadAttention<T>::SetTrainingMode(bool training) {
|
||||
// Dropout will be skipped when not in training model.
|
||||
_attn_prob_dropout.SetTrainingMode(training);
|
||||
_attn_dropout.SetTrainingMode(training);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T *MultiHeadAttention<T>::_shared_mem_ptr = nullptr;
|
||||
|
||||
template class MultiHeadAttention<float>;
|
||||
template class MultiHeadAttention<__half>;
|
||||
|
||||
// x is torch::Tensor
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
|
||||
|
||||
template <typename T>
|
||||
int create_multihead_attention(int layer_id, int max_batch_tokens,
|
||||
int max_seq_len, int hidden_dim, int num_heads,
|
||||
float attn_prob_dropout_ratio,
|
||||
float hidden_dropout_ratio,
|
||||
bool pre_or_postLayerNorm,
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
Context::Instance().set_stream(stream);
|
||||
auto layer = std::make_shared<MultiHeadAttention<T>>(
|
||||
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
|
||||
attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
|
||||
|
||||
layer->SetPG(pg_);
|
||||
|
||||
s_multihead_attention[layer_id] = layer;
|
||||
|
||||
std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<torch::Tensor> multihead_attention_fw(
|
||||
int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
|
||||
const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
|
||||
const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
|
||||
const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
|
||||
bool training_mode, bool prelayernorm) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(input_mask);
|
||||
|
||||
const T *input_ptr = (const T *)input.data_ptr();
|
||||
const T *input_mask_ptr = (const T *)input_mask.data_ptr();
|
||||
|
||||
auto output = torch::empty_like(input);
|
||||
T *out_ptr = (T *)output.data_ptr();
|
||||
|
||||
std::shared_ptr<MultiHeadAttention<T>> layer =
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(
|
||||
s_multihead_attention[layer_id]);
|
||||
layer->set_cur_batch_shape(input.size(0), input.size(1));
|
||||
layer->SetTrainingMode(training_mode);
|
||||
|
||||
layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr();
|
||||
layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr();
|
||||
layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr();
|
||||
layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr();
|
||||
layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr();
|
||||
layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr();
|
||||
|
||||
layer->Forward(input_ptr, input_mask_ptr, out_ptr);
|
||||
|
||||
return {output};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<torch::Tensor> multihead_attention_bw(
|
||||
int layer_id, const torch::Tensor &grad_dec_output,
|
||||
const torch::Tensor &output, const torch::Tensor &input,
|
||||
const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
|
||||
const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
|
||||
const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
|
||||
const torch::Tensor &norm_bias) {
|
||||
auto g_output = grad_dec_output.contiguous();
|
||||
CHECK_INPUT(g_output);
|
||||
CHECK_INPUT(output);
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(input_mask);
|
||||
|
||||
auto grad_input = torch::empty_like(input);
|
||||
auto grad_in_proj_weight = torch::empty_like(in_proj_weight);
|
||||
auto grad_in_proj_bias = torch::empty_like(in_proj_bias);
|
||||
auto grad_out_proj_weight = torch::empty_like(out_proj_weight);
|
||||
auto grad_out_proj_bias = torch::empty_like(out_proj_bias);
|
||||
auto grad_norm_weight = torch::empty_like(norm_weight);
|
||||
auto grad_norm_bias = torch::empty_like(norm_bias);
|
||||
|
||||
// inputs.
|
||||
const T *grad_dec_output_ptr = (const T *)g_output.data_ptr();
|
||||
const T *input_ptr = (const T *)input.data_ptr();
|
||||
const T *output_ptr = (const T *)output.data_ptr();
|
||||
const T *input_mask_ptr = (const T *)input_mask.data_ptr();
|
||||
|
||||
// outputs.
|
||||
T *grad_input_ptr = (T *)grad_input.data_ptr();
|
||||
|
||||
std::shared_ptr<MultiHeadAttention<T>> layer =
|
||||
std::static_pointer_cast<MultiHeadAttention<T>>(
|
||||
s_multihead_attention[layer_id]);
|
||||
layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
|
||||
|
||||
layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
|
||||
layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr();
|
||||
layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr();
|
||||
layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr();
|
||||
layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
|
||||
layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
|
||||
|
||||
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
|
||||
grad_input_ptr);
|
||||
|
||||
return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
|
||||
grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
|
||||
grad_norm_bias};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("multihead_attention_fw_fp32", &multihead_attention_fw<float>,
|
||||
"Multi-head Attention forward with fp32 (CUDA)");
|
||||
m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>,
|
||||
"Multi-head Attention forward with fp16 (CUDA)");
|
||||
m.def("multihead_attention_bw_fp32", &multihead_attention_bw<float>,
|
||||
"Multi-head Attention backward with fp32 (CUDA)");
|
||||
m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>,
|
||||
"Multi-head Attention backward with fp16 (CUDA)");
|
||||
m.def("create_multihead_attention_fp32", &create_multihead_attention<float>,
|
||||
"Create Multi-head Attention with fp32 (CUDA)");
|
||||
m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>,
|
||||
"Create Multi-head Attention with fp16 (CUDA)");
|
||||
}
|
|
@ -1,167 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#if TORCH_VERSION_MAJOR > 1 || \
|
||||
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
#else
|
||||
#include <c10d/ProcessGroup.hpp>
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "cuda_util.h"
|
||||
#include "dropout.h"
|
||||
#include "feed_forward.h"
|
||||
#include "normalize_layer.h"
|
||||
#include "softmax.h"
|
||||
#include "strided_batch_gemm.h"
|
||||
|
||||
template <typename T>
|
||||
class MultiHeadAttention {
|
||||
public:
|
||||
MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
|
||||
int hidden_size, int num_heads, float attn_dropout_ratio,
|
||||
float hidden_output_dropout_ratio,
|
||||
bool pre_or_postLayerNorm);
|
||||
|
||||
virtual ~MultiHeadAttention();
|
||||
|
||||
void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
|
||||
|
||||
void Backward(const T *grad_output_ptr, const T *input_ptr,
|
||||
const T *output_ptr, const T *input_mask_ptr,
|
||||
T *grad_input_ptr);
|
||||
|
||||
void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
|
||||
T *buffer);
|
||||
|
||||
void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
|
||||
const T *output_ptr, const T *grad_output_ptr,
|
||||
T *grad_input_attn_layer_bwptr, T *buffer);
|
||||
|
||||
void set_cur_batch_shape(int batch_size, int seq_len) {
|
||||
_batch_size = batch_size;
|
||||
_seq_len = seq_len;
|
||||
_batch_tokens = batch_size * seq_len;
|
||||
_batch_heads = batch_size * _heads / pg_size;
|
||||
_batch_dim = _batch_tokens * _hidden_size;
|
||||
_attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads);
|
||||
_attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len);
|
||||
}
|
||||
|
||||
void SetTrainingMode(bool training);
|
||||
inline bool IsTrainingMode() const { return _training; }
|
||||
|
||||
void SetPG(c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
|
||||
pg = pg_;
|
||||
pg_size = 1;
|
||||
if (pg != c10::detail::UniqueVoidPtr()) {
|
||||
pg_size = pg->getSize();
|
||||
}
|
||||
allocate_mem_buffer();
|
||||
}
|
||||
|
||||
// weights ptr
|
||||
const T *_attn_qkvw_ptr;
|
||||
const T *_attn_qkvb_ptr;
|
||||
const T *_attn_ow_ptr;
|
||||
const T *_attn_ob_ptr;
|
||||
const T *_attn_nw_ptr;
|
||||
const T *_attn_nb_ptr;
|
||||
|
||||
// grads ptr
|
||||
T *_grad_attn_qkvw_ptr;
|
||||
T *_grad_attn_qkvb_ptr;
|
||||
T *_grad_attn_ow_ptr;
|
||||
T *_grad_attn_ob_ptr;
|
||||
T *_grad_attn_nw_ptr;
|
||||
T *_grad_attn_nb_ptr;
|
||||
|
||||
private:
|
||||
void allocate_mem_buffer() {
|
||||
// allocate local gpu memory
|
||||
if (_pre_or_postLayerNorm) {
|
||||
_gemmQKV_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
|
||||
} else {
|
||||
_gemmQKV_inp_ptr = nullptr;
|
||||
}
|
||||
|
||||
_qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
|
||||
_soft_out_ptr =
|
||||
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
_ctx_bufB_ptr =
|
||||
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
_attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
|
||||
|
||||
// buffer size needed by attn bw
|
||||
size_t smem_size =
|
||||
4 * _max_batch_tokens * _hidden_size / pg_size +
|
||||
std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
|
||||
_max_batch_tokens * _heads / pg_size * _max_seq_len);
|
||||
|
||||
if (!_shared_mem_ptr) {
|
||||
cuda_free(_shared_mem_ptr);
|
||||
_shared_mem_ptr = cuda_malloc<T>(smem_size);
|
||||
}
|
||||
}
|
||||
|
||||
void free_mem_buffer() {
|
||||
// free local gpu memory
|
||||
cuda_free(_gemmQKV_inp_ptr);
|
||||
cuda_free(_qkv_ptr);
|
||||
cuda_free(_soft_out_ptr);
|
||||
cuda_free(_ctx_bufB_ptr);
|
||||
cuda_free(_attn_o_inp_ptr);
|
||||
|
||||
// free shared gpu memory between layers
|
||||
cuda_free(_shared_mem_ptr);
|
||||
_shared_mem_ptr = nullptr;
|
||||
}
|
||||
|
||||
// const parameter between batch
|
||||
const size_t _layer_id;
|
||||
const size_t _hidden_size;
|
||||
const size_t _heads;
|
||||
const size_t _max_batch_tokens;
|
||||
const size_t _max_seq_len;
|
||||
const bool _pre_or_postLayerNorm;
|
||||
// dynamic parameter between batch
|
||||
size_t _batch_size;
|
||||
size_t _seq_len;
|
||||
size_t _batch_tokens;
|
||||
size_t _batch_heads;
|
||||
size_t _batch_dim;
|
||||
bool _training;
|
||||
|
||||
cublasHandle_t _cublasHandle;
|
||||
cudaStream_t _stream;
|
||||
|
||||
// layers
|
||||
FeedForward<T> _qkv_linear;
|
||||
FeedForward<T> _attn_out_linear;
|
||||
Normalize_Layer<T> _attn_ln;
|
||||
Softmax<T> _softmax;
|
||||
Dropout<T> _attn_prob_dropout;
|
||||
Dropout<T> _attn_dropout;
|
||||
StridedBatchGemm<T> _attn_scores;
|
||||
StridedBatchGemm<T> _attn_context;
|
||||
|
||||
// local GPU memory
|
||||
T *_gemmQKV_inp_ptr;
|
||||
T *_qkv_ptr;
|
||||
T *_soft_out_ptr;
|
||||
T *_ctx_bufB_ptr;
|
||||
T *_attn_o_inp_ptr;
|
||||
// shared GPU memory between layer
|
||||
static T *_shared_mem_ptr;
|
||||
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> pg;
|
||||
int pg_size;
|
||||
};
|
|
@ -1,8 +0,0 @@
|
|||
#include <torch/extension.h>
|
||||
|
||||
#include "linear.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
|
||||
"Linear SiLU (INT8)");
|
||||
}
|
|
@ -1,162 +0,0 @@
|
|||
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
|
||||
|
||||
#include "linear.h"
|
||||
#include <cutlass/core_io.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/half.h>
|
||||
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/util/host_tensor.h>
|
||||
#include <cutlass/epilogue/thread/linear_combination_silu.h>
|
||||
#include <cstdint>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <iostream>
|
||||
#include <torch/torch.h>
|
||||
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
|
||||
torch::Tensor weight, // INT8
|
||||
torch::Tensor bias, // FP32
|
||||
float alpha, // FP32
|
||||
float beta // FP32
|
||||
) {
|
||||
auto M = input.size(0);
|
||||
auto N = weight.size(0);
|
||||
auto K = input.size(1);
|
||||
|
||||
using ElementOutput = float;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementComputeEpilogue = float;
|
||||
using ElementInputA = int8_t; // <- data type of elements in input matrix A
|
||||
using ElementInputB = int8_t; // <- data type of elements in input matrix B
|
||||
|
||||
// The code section below describes matrix layout of input and output
|
||||
// matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major
|
||||
// for Matrix C
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
#if CUDA_ARCH >= 800
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementOutput>::value, // <- this is the number of elements per
|
||||
// vectorized memory access. For half
|
||||
// precision, it's 8 elements. This
|
||||
// becomes the vector width of math
|
||||
// instructions in epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue // <- data type for alpha in linear combination
|
||||
// function
|
||||
>;
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
|
||||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<256, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
|
||||
EpilogueOp,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
|
||||
#elif CUDA_ARCH >= 750
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementOutput>::value, // <- this is the number of elements per
|
||||
// vectorized memory access. For half
|
||||
// precision, it's 8 elements. This
|
||||
// becomes the vector width of math
|
||||
// instructions in epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue // <- data type for alpha in linear combination
|
||||
// function
|
||||
>;
|
||||
|
||||
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
|
||||
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
|
||||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
|
||||
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
|
||||
DefaultGemmCfg::InstructionShape,
|
||||
EpilogueOp>;
|
||||
#elif CUDA_ARCH >= 700
|
||||
#define USE_TORCH_SILU
|
||||
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
|
||||
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
|
||||
using Gemm = cutlass::gemm::device::Gemm<
|
||||
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
|
||||
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
|
||||
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
|
||||
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
|
||||
DefaultGemmCfg::InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
|
||||
#else
|
||||
#error "Unsupported cuda arch"
|
||||
#endif
|
||||
|
||||
auto input_size = cutlass::MatrixCoord(M, K);
|
||||
auto weight_size = cutlass::MatrixCoord(K, N);
|
||||
auto output_size = cutlass::MatrixCoord(M, N);
|
||||
|
||||
auto device = input.device();
|
||||
// use the broadcasted bias as the output
|
||||
auto out = bias.to(device).view({1, -1}).repeat({M, 1});
|
||||
|
||||
// constexpr int kSparse = Gemm::kSparse;
|
||||
// How many elements of A are covered per ElementE
|
||||
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
|
||||
// The size of individual meta data
|
||||
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
|
||||
cutlass::gemm::GemmCoord problem_size(M, N, K);
|
||||
|
||||
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
|
||||
input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
|
||||
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
|
||||
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
|
||||
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
|
||||
out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
problem_size, // <- problem size of matrix multiplication
|
||||
input_ref, // <- reference to matrix A on device
|
||||
weight_ref, // <- reference to matrix B on device
|
||||
out_ref, // <- reference to matrix C on device
|
||||
out_ref, // <- reference to matrix D on device
|
||||
{alpha, beta}, 1};
|
||||
Gemm gemm_op;
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix
|
||||
// multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot implement");
|
||||
}
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot initialize");
|
||||
}
|
||||
|
||||
status = gemm_op();
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot run");
|
||||
}
|
||||
#ifdef USE_TORCH_SILU
|
||||
#undef USE_TORCH_SILU
|
||||
out = torch::silu(out);
|
||||
#endif
|
||||
return out;
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
#include <torch/torch.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
||||
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
|
||||
torch::Tensor weight, // INT8
|
||||
torch::Tensor bias, // FP32
|
||||
float alpha, // FP32
|
||||
float beta // FP32
|
||||
);
|
|
@ -1,3 +0,0 @@
|
|||
from .mha import ColoAttention
|
||||
|
||||
__all__ = ["ColoAttention"]
|
|
@ -1,80 +0,0 @@
|
|||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_ampere_or_better_gpu():
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
properties = torch.cuda.get_device_properties(device)
|
||||
if properties.major >= 8: # Ampere GPUs or newer
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# "Check Ampere GPUs or newer"
|
||||
HAS_FLASH_ATTN = False
|
||||
if is_ampere_or_better_gpu():
|
||||
HAS_FLASH_ATTN = True
|
||||
else:
|
||||
warnings.warn("FlashAttention only supports Ampere GPUs or newer.")
|
||||
HAS_FLASH_ATTN = False
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
|
||||
|
||||
HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention")
|
||||
HAS_FLASH_ATTN = False
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
pass
|
||||
|
||||
from .utils import SeqLenInfo
|
||||
|
||||
def flash_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seq_len_info_q: SeqLenInfo,
|
||||
seq_len_info_kv: SeqLenInfo,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: float = None,
|
||||
causal: bool = False,
|
||||
padded: bool = False,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch, q_seqlen, nheads, headdim)
|
||||
k: (batch, kv_seqlen, nheads, headdim)
|
||||
v: (batch, kv_seqlen, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
dropout_p: float. Dropout probability.
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
Return:
|
||||
attn_out: (batch, q_seqlen, nheads, headdim).
|
||||
"""
|
||||
if padded:
|
||||
if seq_len_info_kv == None:
|
||||
seq_len_info_kv = seq_len_info_q
|
||||
|
||||
attn_out = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
seq_len_info_q.cu_seqlens,
|
||||
seq_len_info_kv.cu_seqlens,
|
||||
seq_len_info_q.max_seqlen,
|
||||
seq_len_info_kv.max_seqlen,
|
||||
dropout_p,
|
||||
scale,
|
||||
causal,
|
||||
)
|
||||
else:
|
||||
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
|
||||
return attn_out
|
|
@ -1,70 +0,0 @@
|
|||
import warnings
|
||||
|
||||
HAS_MEM_EFF_ATTN = False
|
||||
try:
|
||||
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
|
||||
from xformers.ops.fmha.attn_bias import (
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
)
|
||||
|
||||
HAS_MEM_EFF_ATTN = True
|
||||
except ImportError:
|
||||
warnings.warn("please install xformers from https://github.com/facebookresearch/xformers")
|
||||
HAS_MEM_EFF_ATTN = False
|
||||
|
||||
if HAS_MEM_EFF_ATTN:
|
||||
"""
|
||||
A general attention module using the flash attention kernels from xformers:
|
||||
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import SeqLenInfo
|
||||
|
||||
allow_alibi = True
|
||||
for op in MemoryEfficientAttentionCutlassOp:
|
||||
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
|
||||
|
||||
def mem_eff_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seq_len_info_q: SeqLenInfo,
|
||||
seq_len_info_kv: SeqLenInfo,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: float = None,
|
||||
causal: bool = False,
|
||||
padded: bool = False,
|
||||
):
|
||||
attn_bias = None
|
||||
if padded: # bert style
|
||||
if not causal:
|
||||
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
|
||||
else:
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
|
||||
elif causal: # gpt style
|
||||
attn_bias = LowerTriangularMask()
|
||||
|
||||
if bias is not None: # alibi / relative position embedding
|
||||
assert allow_alibi, "flash attention with bias is not supported in this system."
|
||||
assert causal, "attention with bias is only supported for causal attention so far."
|
||||
attn_bias = attn_bias.add_bias(bias)
|
||||
|
||||
if padded:
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
|
||||
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
|
||||
|
||||
# shape: (b*s, n, d)
|
||||
if padded:
|
||||
out = out.squeeze(0)
|
||||
|
||||
return out
|
|
@ -1,113 +0,0 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from ..scaled_softmax import AttnMaskType
|
||||
from .flash_attn_2 import HAS_FLASH_ATTN
|
||||
from .mem_eff_attn import HAS_MEM_EFF_ATTN
|
||||
from .utils import Repad, SeqLenInfo, Unpad
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
from .flash_attn_2 import flash_attention
|
||||
if HAS_MEM_EFF_ATTN:
|
||||
from .mem_eff_attn import mem_eff_attention
|
||||
|
||||
|
||||
class ColoAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
|
||||
super().__init__()
|
||||
assert (
|
||||
embed_dim % num_heads == 0
|
||||
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
|
||||
if scale is not None:
|
||||
self.scale = scale
|
||||
else:
|
||||
self.scale = 1 / math.sqrt(embed_dim // num_heads)
|
||||
self.dropout = dropout
|
||||
|
||||
if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
|
||||
raise Exception("flash attention can not support!")
|
||||
|
||||
@staticmethod
|
||||
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||
return Unpad.apply(tensor, indices)
|
||||
|
||||
@staticmethod
|
||||
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
|
||||
return Repad.apply(tensor, indices, batch_size, seq_len)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask_type: Optional[AttnMaskType] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
attn = None
|
||||
if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None:
|
||||
attn = flash_attention
|
||||
else:
|
||||
attn = mem_eff_attention
|
||||
|
||||
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
|
||||
causal = attn_mask_type is not None and attn_mask_type.value > 1
|
||||
|
||||
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
|
||||
# unpad
|
||||
seq_len_info_q = None
|
||||
seq_len_info_kv = None
|
||||
if padded:
|
||||
# bert style, unpad process
|
||||
assert (
|
||||
attn_mask is not None
|
||||
), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
|
||||
assert attn_mask.dim() == 2, (
|
||||
"attention mask is supposed to have shape (batch_size, seq_len), "
|
||||
+ f"but got {attn_mask.dim()} dimensions."
|
||||
)
|
||||
|
||||
# bert style
|
||||
if tgt_len == src_len:
|
||||
seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
|
||||
if batch_size > 1:
|
||||
query, key, value = self.unpad(
|
||||
torch.stack([query, key, value], dim=2), seq_len_info_q.indices
|
||||
).unbind(dim=1)
|
||||
else:
|
||||
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
|
||||
seq_len_info_kv = seq_len_info_q
|
||||
else:
|
||||
seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
|
||||
seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
|
||||
if batch_size > 1:
|
||||
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
|
||||
key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
|
||||
dim=1
|
||||
)
|
||||
else:
|
||||
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
|
||||
|
||||
out = attn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
seq_len_info_q,
|
||||
seq_len_info_kv,
|
||||
dropout_p=self.dropout,
|
||||
scale=self.scale,
|
||||
causal=causal,
|
||||
padded=padded,
|
||||
)
|
||||
|
||||
# repad
|
||||
if padded:
|
||||
if batch_size > 1:
|
||||
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
|
||||
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
|
||||
|
||||
out = rearrange(out, "b s h d -> b s (h d)")
|
||||
return out
|
|
@ -1,82 +0,0 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
|
||||
class Unpad(torch.autograd.Function):
|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
|
||||
ctx.save_for_backward(indices)
|
||||
# [b, s, ...]
|
||||
assert tensor.ndim >= 3
|
||||
ctx.bsz = tensor.shape[0]
|
||||
out = rearrange(tensor, "b s ... -> (b s) ...")
|
||||
ctx.shape = out.shape
|
||||
# [ntokens, ...]
|
||||
return out[indices]
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(indices,) = ctx.saved_tensors
|
||||
# [ntokens, ...]
|
||||
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
|
||||
grad[indices] = grad_output
|
||||
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
|
||||
# [b, s, ...]
|
||||
return grad, None
|
||||
|
||||
|
||||
class Repad(torch.autograd.Function):
|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
|
||||
ctx.save_for_backward(indices)
|
||||
# [ntokens, ...]
|
||||
tensor = tensor
|
||||
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
|
||||
# [b*s, ...]
|
||||
out[indices] = tensor
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(indices,) = ctx.saved_tensors
|
||||
# [b*s, ...]
|
||||
grad = grad_output[indices]
|
||||
# [ntokens, ...]
|
||||
return grad, None, None, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeqLenInfo:
|
||||
seqlens: Iterable[int] = None
|
||||
indices: torch.Tensor = None
|
||||
max_seqlen: int = None
|
||||
cu_seqlens: torch.Tensor = None
|
||||
|
||||
@staticmethod
|
||||
def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
|
||||
if attn_mask is not None:
|
||||
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
|
||||
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
|
||||
else:
|
||||
batch_size, tgt_len = size[0], size[1]
|
||||
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
|
||||
seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
|
||||
max_seqlen = max(seqlens)
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
|
||||
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
|
|
@ -1,338 +0,0 @@
|
|||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
def check_config(config):
|
||||
if config.hidden_size % config.nhead != 0:
|
||||
raise Exception("hidden_size % nhead != 0")
|
||||
|
||||
factor = 8 if config.fp16 else 4
|
||||
upbound = factor * 1024 * 4
|
||||
if config.hidden_size > upbound:
|
||||
# as required by ln backward kernel currently
|
||||
raise Exception(f"hidden_size > {upbound}")
|
||||
|
||||
head_dim = config.hidden_size // config.nhead
|
||||
if head_dim % factor != 0:
|
||||
# as required by reshape kernel
|
||||
raise Exception(f"head_dim({head_dim}) % {factor} != 0")
|
||||
|
||||
|
||||
def calc_offset(sizes):
|
||||
offsets = [0]
|
||||
tmp = 0
|
||||
for x in sizes:
|
||||
tmp += x
|
||||
offsets.append(tmp)
|
||||
return offsets
|
||||
|
||||
|
||||
colossal_multihead_attention = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
max_batch_tokens: int # max batch token numbers
|
||||
max_seq_len: int # max sequence length
|
||||
hidden_size: int # size of transformer hidden layers
|
||||
nhead: int # number of heads in attention
|
||||
attn_prob_dropout_ratio: float # attention score dropout ratio
|
||||
hidden_dropout_ratio: float # dropout ration before residual
|
||||
norm_first: bool # norm_first
|
||||
fp16: bool # fp16 precision
|
||||
|
||||
|
||||
class MultiHeadAttention1DFunc(Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
input,
|
||||
input_mask,
|
||||
in_proj_weight,
|
||||
in_proj_bias,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
config,
|
||||
):
|
||||
cuda_module = colossal_multihead_attention
|
||||
forward_func = (
|
||||
cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32
|
||||
)
|
||||
if config.fp16:
|
||||
input = input.to(torch.half)
|
||||
input_mask = input_mask.to(torch.half)
|
||||
|
||||
(output,) = forward_func(
|
||||
config.layer_id,
|
||||
input,
|
||||
input_mask,
|
||||
in_proj_weight,
|
||||
in_proj_bias,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
config.training,
|
||||
config.norm_first,
|
||||
)
|
||||
|
||||
if config.is_grad_enabled and config.training:
|
||||
ctx.save_for_backward(
|
||||
output,
|
||||
input,
|
||||
input_mask,
|
||||
in_proj_weight,
|
||||
in_proj_bias,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
)
|
||||
ctx.config = config
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
assert ctx.config.training
|
||||
|
||||
cuda_module = colossal_multihead_attention
|
||||
backward_func = (
|
||||
cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32
|
||||
)
|
||||
|
||||
(
|
||||
output,
|
||||
input,
|
||||
input_mask,
|
||||
in_proj_weight,
|
||||
in_proj_bias,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
) = ctx.saved_tensors
|
||||
|
||||
grad_input = None
|
||||
grad_in_proj_weight = None
|
||||
grad_in_proj_bias = None
|
||||
grad_out_proj_weight = None
|
||||
grad_out_proj_bias = None
|
||||
grad_norm_weight = None
|
||||
grad_norm_bias = None
|
||||
|
||||
if ctx.config.fp16:
|
||||
grad_output = grad_output.to(torch.half)
|
||||
output = output.to(torch.half)
|
||||
input = input.to(torch.half)
|
||||
input_mask = input_mask.to(torch.half)
|
||||
(
|
||||
grad_input,
|
||||
grad_in_proj_weight,
|
||||
grad_in_proj_bias,
|
||||
grad_out_proj_weight,
|
||||
grad_out_proj_bias,
|
||||
grad_norm_weight,
|
||||
grad_norm_bias,
|
||||
) = backward_func(
|
||||
ctx.config.layer_id,
|
||||
grad_output,
|
||||
output,
|
||||
input,
|
||||
input_mask,
|
||||
in_proj_weight,
|
||||
in_proj_bias,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
)
|
||||
|
||||
return (
|
||||
grad_input,
|
||||
None,
|
||||
grad_in_proj_weight,
|
||||
grad_in_proj_bias,
|
||||
grad_out_proj_weight,
|
||||
grad_out_proj_bias,
|
||||
grad_norm_weight,
|
||||
grad_norm_bias,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Initialize the MultiHeadAttention.
|
||||
|
||||
Static variable:
|
||||
|
||||
layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
|
||||
e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.
|
||||
|
||||
Arguments:
|
||||
hidden_size: Total dimension of hidden_size.
|
||||
nhead: Number of parallel attention heads.
|
||||
batch_size: Batch Size for one forward
|
||||
max_seq_len: Max length of input sequence
|
||||
dropout: Dropout probability
|
||||
norm_first: perform LayerNorms before attention
|
||||
"""
|
||||
|
||||
layer_id = 0
|
||||
|
||||
def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
|
||||
self.config = Config(
|
||||
batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16
|
||||
)
|
||||
check_config(self.config)
|
||||
self.pg = pg
|
||||
self.pg_size = 1
|
||||
if self.pg:
|
||||
self.pg_size = pg.size()
|
||||
self.config.layer_id = MultiHeadAttention.layer_id
|
||||
MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1
|
||||
|
||||
# Load cuda modules if needed
|
||||
global colossal_multihead_attention
|
||||
if colossal_multihead_attention is None:
|
||||
from colossalai.kernel.op_builder import MultiHeadAttnBuilder
|
||||
|
||||
multihead_attention = MultiHeadAttnBuilder().load()
|
||||
colossal_multihead_attention = multihead_attention
|
||||
|
||||
# create the layer in cuda kernels.
|
||||
cuda_module = colossal_multihead_attention
|
||||
create_layer_func = (
|
||||
cuda_module.create_multihead_attention_fp16
|
||||
if self.config.fp16
|
||||
else cuda_module.create_multihead_attention_fp32
|
||||
)
|
||||
|
||||
create_layer_func(
|
||||
self.config.layer_id,
|
||||
self.config.max_batch_tokens,
|
||||
self.config.max_seq_len,
|
||||
self.config.hidden_size,
|
||||
self.config.nhead,
|
||||
self.config.attn_prob_dropout_ratio,
|
||||
self.config.hidden_dropout_ratio,
|
||||
self.config.norm_first,
|
||||
self.pg,
|
||||
)
|
||||
|
||||
hs = self.config.hidden_size
|
||||
|
||||
self.precision = torch.float32
|
||||
if self.config.fp16:
|
||||
self.precision = torch.half
|
||||
|
||||
self.hs_per_rank = int(hs / self.pg_size)
|
||||
|
||||
self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs))
|
||||
self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank))
|
||||
self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank))
|
||||
self.out_proj_bias = nn.Parameter(torch.Tensor(hs))
|
||||
self.norm_weight = nn.Parameter(torch.Tensor(hs))
|
||||
self.norm_bias = nn.Parameter(torch.Tensor(hs))
|
||||
|
||||
self.reset_parameters()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def calc_bound(self, w):
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
|
||||
bound = 1.0 / math.sqrt(fan_in)
|
||||
return bound
|
||||
|
||||
def reset_parameters(self):
|
||||
hs = self.config.hidden_size
|
||||
|
||||
nn.init.zeros_(self.out_proj_bias)
|
||||
|
||||
nn.init.ones_(self.norm_weight)
|
||||
nn.init.zeros_(self.norm_bias)
|
||||
|
||||
if self.pg_size > 1:
|
||||
rank_in_pg = torch.distributed.get_rank(self.pg)
|
||||
attn_qkvw_global = torch.empty(hs * 3, hs)
|
||||
attn_qkvb_global = torch.empty(hs * 3)
|
||||
nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0))
|
||||
bound = self.calc_bound(attn_qkvw_global)
|
||||
nn.init.uniform_(attn_qkvb_global, -bound, bound)
|
||||
|
||||
attn_qkvw_global = attn_qkvw_global.cuda()
|
||||
attn_qkvb_global = attn_qkvb_global.cuda()
|
||||
torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg)
|
||||
torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg)
|
||||
attn_qkvw_global = attn_qkvw_global.cpu()
|
||||
attn_qkvb_global = attn_qkvb_global.cpu()
|
||||
|
||||
with torch.no_grad():
|
||||
self.in_proj_weight.copy_(
|
||||
attn_qkvw_global.view(3, hs, hs)[
|
||||
:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), :
|
||||
]
|
||||
)
|
||||
self.in_proj_bias.copy_(
|
||||
attn_qkvb_global.view(3, hs)[
|
||||
:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)
|
||||
]
|
||||
)
|
||||
|
||||
attn_ow_global = torch.empty(hs, hs)
|
||||
nn.init.xavier_uniform_(attn_ow_global, 1.0)
|
||||
attn_ow_global = attn_ow_global.cuda()
|
||||
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
|
||||
attn_ow_global = attn_ow_global.cpu()
|
||||
with torch.no_grad():
|
||||
self.out_proj_weight.copy_(
|
||||
attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)]
|
||||
)
|
||||
|
||||
else:
|
||||
attn_qkvw = self.in_proj_weight.view(-1, hs)
|
||||
nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0))
|
||||
bound = self.calc_bound(attn_qkvw)
|
||||
nn.init.uniform_(self.in_proj_bias, -bound, bound)
|
||||
|
||||
nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
return destination
|
||||
|
||||
def forward(self, hidden_states, encoder_padding_mask):
|
||||
self.config.training = self.training
|
||||
self.config.is_grad_enabled = torch.is_grad_enabled()
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()
|
||||
|
||||
bs, sl, dim = hidden_states.size()
|
||||
if bs * sl > self.config.max_batch_tokens:
|
||||
raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
|
||||
if sl > self.config.max_seq_len:
|
||||
raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
|
||||
if len(encoder_padding_mask.size()) == 1:
|
||||
assert bs == 1 and sl == encoder_padding_mask.size(0)
|
||||
else:
|
||||
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
|
||||
|
||||
output = MultiHeadAttention1DFunc.apply(
|
||||
hidden_states,
|
||||
encoder_padding_mask,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj_weight,
|
||||
self.out_proj_bias,
|
||||
self.norm_weight,
|
||||
self.norm_bias,
|
||||
self.config,
|
||||
)
|
||||
|
||||
return output.to(self.precision)
|
|
@ -0,0 +1 @@
|
|||
../../extensions
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .bias_dropout_add import bias_dropout_add_fused_train
|
||||
from .bias_gelu import bias_gelu_impl
|
||||
|
@ -46,11 +46,13 @@ def warmup_jit_fusion(
|
|||
):
|
||||
"""Compile JIT functions before the main training steps"""
|
||||
|
||||
embed = Embedding(vocab_size, hidden_size).to(get_current_device())
|
||||
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
|
||||
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
|
||||
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
|
||||
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
|
||||
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device())
|
||||
|
||||
x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
|
||||
x = torch.randint(
|
||||
vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device()
|
||||
)
|
||||
x = embed(x)
|
||||
y, y_bias = linear_1(x)
|
||||
z, z_bias = linear_2(y)
|
||||
|
@ -58,8 +60,8 @@ def warmup_jit_fusion(
|
|||
# prop and recomputation
|
||||
for bias_grad, input_grad in zip([True, True], [False, True]):
|
||||
for _ in range(10):
|
||||
bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
|
||||
input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
|
||||
bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
bias.requires_grad, input_.requires_grad = bias_grad, input_grad
|
||||
bias_gelu_impl(input_, bias)
|
||||
|
||||
|
@ -69,9 +71,9 @@ def warmup_jit_fusion(
|
|||
# prop and recomputation
|
||||
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
|
||||
for _ in range(10):
|
||||
input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
|
||||
residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
|
||||
bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
|
||||
input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
input_.requires_grad = input_grad
|
||||
bias.requires_grad = bias_grad
|
||||
residual.requires_grad = residual_grad
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
import warnings
|
||||
from typing import List
|
||||
|
||||
from .extensions import (
|
||||
CpuAdamArmExtension,
|
||||
CpuAdamX86Extension,
|
||||
FlashAttentionDaoCudaExtension,
|
||||
FlashAttentionNpuExtension,
|
||||
FlashAttentionXformersCudaExtension,
|
||||
FusedOptimizerCudaExtension,
|
||||
LayerNormCudaExtension,
|
||||
MoeCudaExtension,
|
||||
ScaledMaskedSoftmaxCudaExtension,
|
||||
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
|
||||
)
|
||||
from .extensions.base_extension import _Extension
|
||||
|
||||
__all__ = [
|
||||
"KernelLoader",
|
||||
"CPUAdamLoader",
|
||||
"LayerNormLoader",
|
||||
"MoeLoader",
|
||||
"FusedOptimizerLoader",
|
||||
"ScaledMaskedSoftmaxLoader",
|
||||
"ScaledUpperTriangleMaskedSoftmaxLoader",
|
||||
]
|
||||
|
||||
|
||||
class KernelLoader:
|
||||
"""
|
||||
An abstract class which offers encapsulation to the kernel loading process.
|
||||
|
||||
Usage:
|
||||
kernel_loader = KernelLoader()
|
||||
kernel = kernel_loader.load()
|
||||
"""
|
||||
|
||||
REGISTRY: List[_Extension] = []
|
||||
|
||||
@classmethod
|
||||
def register_extension(cls, extension: _Extension):
|
||||
"""
|
||||
This classmethod is an extension point which allows users to register their customized
|
||||
kernel implementations to the loader.
|
||||
|
||||
Args:
|
||||
extension (_Extension): the extension to be registered.
|
||||
"""
|
||||
cls.REGISTRY.append(extension)
|
||||
|
||||
def load(self, ext_name: str = None):
|
||||
"""
|
||||
Load the kernel according to the current machine.
|
||||
|
||||
Args:
|
||||
ext_name (str): the name of the extension to be loaded. If not specified, the loader
|
||||
will try to look for an kernel available on the current machine.
|
||||
"""
|
||||
exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]
|
||||
|
||||
# look for exts which can be built/loaded on the current machine
|
||||
|
||||
if ext_name:
|
||||
usable_exts = list(filter(lambda ext: ext.name == ext_name, exts))
|
||||
else:
|
||||
usable_exts = []
|
||||
for ext in exts:
|
||||
if ext.is_hardware_available():
|
||||
# make sure the machine is compatible during kernel loading
|
||||
ext.assert_hardware_compatible()
|
||||
usable_exts.append(ext)
|
||||
|
||||
assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
|
||||
|
||||
if len(usable_exts) > 1:
|
||||
# if more than one usable kernel is found, we will try to load the kernel with the highest priority
|
||||
usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
|
||||
warnings.warn(
|
||||
f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}"
|
||||
)
|
||||
return usable_exts[0].load()
|
||||
|
||||
|
||||
class CPUAdamLoader(KernelLoader):
|
||||
REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]
|
||||
|
||||
|
||||
class LayerNormLoader(KernelLoader):
|
||||
REGISTRY = [LayerNormCudaExtension]
|
||||
|
||||
|
||||
class MoeLoader(KernelLoader):
|
||||
REGISTRY = [MoeCudaExtension]
|
||||
|
||||
|
||||
class FusedOptimizerLoader(KernelLoader):
|
||||
REGISTRY = [FusedOptimizerCudaExtension]
|
||||
|
||||
|
||||
class ScaledMaskedSoftmaxLoader(KernelLoader):
|
||||
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]
|
||||
|
||||
|
||||
class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
|
||||
REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]
|
||||
|
||||
|
||||
class FlashAttentionLoader(KernelLoader):
|
||||
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
|
|
@ -1 +0,0 @@
|
|||
../../op_builder
|
|
@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
|
|||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
|
||||
from colossalai.kernel.op_builder import FusedOptimBuilder
|
||||
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
|
||||
from colossalai.legacy.context import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
|
||||
|
@ -28,7 +28,7 @@ def load_fused_optim():
|
|||
global fused_optim
|
||||
|
||||
if fused_optim is None:
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
fused_optim = FusedOptimizerLoader().load()
|
||||
|
||||
|
||||
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from colossalai.utils.device import autocast
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.legacy.utils import clip_grad_norm_fp32
|
||||
|
||||
from ._grad_scaler import GradScaler
|
||||
|
||||
autocast = get_accelerator().autocast
|
||||
|
||||
|
||||
class TorchAMPOptimizer(OptimizerWrapper):
|
||||
"""A wrapper class which integrate Pytorch AMP with an optimizer
|
||||
|
|
|
@ -8,9 +8,9 @@ from typing import List, Tuple, Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
|
||||
|
||||
|
@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
|
|||
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
|
||||
if isinstance(recv_shapes, torch.Size):
|
||||
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
|
||||
buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
|
||||
buffer_recv = torch.empty(
|
||||
recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
return buffer_recv, recv_split
|
||||
buffer_recv = []
|
||||
for recv_shape in recv_shapes:
|
||||
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
|
||||
tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
|
||||
tensor_recv = torch.empty(
|
||||
recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
buffer_recv.append(tensor_recv)
|
||||
return buffer_recv, recv_split
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device, synchronize
|
||||
|
||||
|
||||
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
|
||||
|
@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
|
|||
current_rank = gpc.get_global_rank()
|
||||
|
||||
tensor_recv_prev = torch.empty(
|
||||
buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
|
||||
buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype
|
||||
)
|
||||
|
||||
# send to next rank
|
||||
|
@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
|
|||
req.wait()
|
||||
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
synchronize()
|
||||
get_accelerator().synchronize()
|
||||
|
||||
return tensor_recv_prev
|
||||
|
|
|
@ -3,9 +3,9 @@ from typing import List, Tuple, Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
|
|||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
|
||||
if isinstance(obj, torch.Tensor):
|
||||
send_obj_nums = torch.tensor(1, **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
|
@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
|
|||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
|
||||
recv_obj_nums = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_obj_nums, prev_rank)
|
||||
if recv_obj_nums.item() == 1:
|
||||
|
|
|
@ -6,8 +6,8 @@ from typing import Callable, Iterable
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class BaseSchedule(ABC):
|
||||
|
@ -29,12 +29,12 @@ class BaseSchedule(ABC):
|
|||
def _move_tensor(element):
|
||||
if torch.is_tensor(element):
|
||||
if not element.is_cuda:
|
||||
return element.to(get_current_device()).detach()
|
||||
return element.to(get_accelerator().get_current_device()).detach()
|
||||
return element
|
||||
|
||||
def _move_to_device(self, data):
|
||||
if isinstance(data, torch.Tensor):
|
||||
data = data.to(get_current_device())
|
||||
data = data.to(get_accelerator().get_current_device())
|
||||
elif isinstance(data, (list, tuple)):
|
||||
data_to_return = []
|
||||
for element in data:
|
||||
|
|
|
@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union
|
|||
import torch.cuda
|
||||
|
||||
import colossalai.legacy.communication as comm
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ._base_schedule import BaseSchedule
|
||||
|
||||
|
@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
output_objs = []
|
||||
return_tensors = []
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
# Used for tensor meta information communication
|
||||
|
@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
|||
if not forward_only:
|
||||
output_obj_grads = [[] for _ in range(len(model))]
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
|
|
|
@ -6,10 +6,10 @@ from typing import Iterable, Tuple
|
|||
import torch.cuda
|
||||
|
||||
import colossalai.legacy.communication.p2p_v2 as comm
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.engine import Engine
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ._pipeline_schedule import PipelineSchedule
|
||||
|
||||
|
@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule):
|
|||
output_objs = []
|
||||
return_tensors = []
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.context import Config, ConfigException
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
|
||||
|
@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence
|
|||
from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
|
||||
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_default_parser():
|
||||
|
@ -309,9 +309,9 @@ def initialize(
|
|||
else:
|
||||
if isinstance(model, nn.Module):
|
||||
# first sync model across dp ranks
|
||||
model.to(get_current_device())
|
||||
model.to(get_accelerator().get_current_device())
|
||||
elif isinstance(model, Callable):
|
||||
model = model().to(get_current_device())
|
||||
model = model().to(get_accelerator().get_current_device())
|
||||
|
||||
# optimizer maybe a optimizer_cls
|
||||
if isinstance(optimizer, Callable):
|
||||
|
|
|
@ -3,8 +3,8 @@ from typing import Callable
|
|||
|
||||
from torch import dtype, nn
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.nn import init
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
|
||||
from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
|
||||
|
@ -83,7 +83,7 @@ class Embedding(ColossalaiModule):
|
|||
embed = (
|
||||
nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs)
|
||||
.to(dtype)
|
||||
.to(get_current_device())
|
||||
.to(get_accelerator().get_current_device())
|
||||
)
|
||||
weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
|
||||
elif num_embeddings <= vocab_parallel_limit:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from torch import nn
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from ..parallel_1d import LayerNorm1D
|
||||
from ..parallel_2d import LayerNorm2D
|
||||
|
@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule):
|
|||
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel is None:
|
||||
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
|
||||
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device())
|
||||
else:
|
||||
norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
|
||||
super().__init__(norm)
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.kernel import LayerNorm
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import broadcast
|
||||
from colossalai.legacy.context import ParallelMode, seed
|
||||
from colossalai.legacy.context.parallel_context import global_context as gpc
|
||||
|
@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
|
|||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..colossalai_layer._utils import ColossalaiModule
|
||||
|
@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer):
|
|||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
|
@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
|
@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer):
|
|||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
|
||||
|
||||
if bias:
|
||||
|
@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer):
|
|||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
|
@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer):
|
|||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
|||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(self.num_embeddings_per_partition, self.embed_dim),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
|
|
@ -5,10 +5,10 @@ import torch.distributed as dist
|
|||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def matmul_2d(
|
||||
|
@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[0])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[-1], B.shape[-1])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import broadcast
|
||||
from colossalai.legacy.context import ParallelMode, seed
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import (
|
|||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
|
@ -82,7 +82,7 @@ class Linear2D(ParallelLayer):
|
|||
self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
|
||||
)
|
||||
|
@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer):
|
|||
self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
|
||||
|
||||
# create parameters
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
|
||||
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
if bias:
|
||||
|
@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
|
||||
self.cls_token = Parameter(
|
||||
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
(1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
self.pos_embed = Parameter(
|
||||
torch.zeros(
|
||||
(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
|
||||
(1, self.num_patches + 1, self.embed_size_per_partition),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer):
|
|||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer):
|
|||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
self.num_classes,
|
||||
self.input_size_per_partition,
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer):
|
|||
self.output_size_per_partition = divide(num_classes, self.summa_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
|
|
|
@ -5,10 +5,10 @@ import torch.distributed as dist
|
|||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_parallel_group(parallel_mode: ParallelMode):
|
||||
|
@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[0])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[-1], B.shape[-1])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
|
|||
if row_rank == 0:
|
||||
bias_temp = bias.clone()
|
||||
else:
|
||||
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
|
||||
bias_temp = torch.zeros(
|
||||
output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device()
|
||||
)
|
||||
src_rank = (
|
||||
col_rank
|
||||
+ dep_rank * tesseract_dim**2
|
||||
|
@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function):
|
|||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
|
||||
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
|
||||
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device())
|
||||
dist.all_gather(
|
||||
list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode)
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import broadcast
|
||||
from colossalai.legacy.context import ParallelMode, seed
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import (
|
|||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
|
@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer):
|
|||
self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
|
||||
)
|
||||
|
@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
|
|||
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # *
|
||||
|
||||
# create parameters
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
|
||||
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
if bias:
|
||||
|
@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
|
||||
self.cls_token = Parameter(
|
||||
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
(1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
self.pos_embed = Parameter(
|
||||
torch.zeros(
|
||||
(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
|
||||
(1, self.num_patches + 1, self.embed_size_per_partition),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer):
|
|||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer):
|
|||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
self.num_classes,
|
||||
self.input_size_per_partition,
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
|
|||
self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import all_reduce, broadcast
|
||||
from colossalai.legacy.constants import (
|
||||
INPUT_GROUP_3D,
|
||||
|
@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import (
|
|||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ._operation import (
|
||||
|
@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer):
|
|||
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -202,13 +204,15 @@ class Linear3D(ParallelLayer):
|
|||
torch.empty(
|
||||
self.in_features_per_partition,
|
||||
self.out_features_per_partition,
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer):
|
|||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
self.num_classes,
|
||||
self.in_features_per_partition,
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer):
|
|||
torch.empty(
|
||||
self.out_features_per_partition,
|
||||
self.in_features_per_partition,
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(
|
||||
(embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
|
||||
(embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.bias = nn.Parameter(
|
||||
torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(
|
||||
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
(1, self.num_patches + 1, embed_size_per_partition),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
|
||||
|
@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer):
|
|||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
@ -1019,7 +1042,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue